seconohe.torch
# -*- coding: utf-8 -*-
# Copyright (c) 2025 Salvador E. Tropea
# Copyright (c) 2025 Instituto Nacional de Tecnología Industrial
# License: GPL-3.0
# Project: SeCoNoHe
# PyTorch helpers
import contextlib  # For context manager
import logging
import torch
from typing import Optional, Iterator, cast
try:
    import comfy.model_management as mm
    with_comfy = True
except Exception:
    with_comfy = False


def get_torch_device_options() -> tuple[list[str], str]:
    """
    Detects available PyTorch devices and returns a list and a suitable default.

    Scans for CPU, CUDA devices, and MPS (for Apple Silicon), providing a list
    of device strings (e.g., 'cpu', 'cuda', 'cuda:0') and a recommended default.

    :return: A tuple containing the list of available device strings and the
             recommended default device string.
    :rtype: tuple[list[str], str]
    """
    # We always have CPU
    default = "cpu"
    options = [default]
    # Do we have CUDA?
    if torch.cuda.is_available():
        default = "cuda"
        options.append(default)
        for i in range(torch.cuda.device_count()):
            options.append(f"cuda:{i}")  # Specific CUDA devices
    # Is this a Mac?
    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
        options.append("mps")
        if default == "cpu":
            default = "mps"
    return options, default


def get_offload_device() -> torch.device:
    """
    Gets the appropriate device for offloading models.

    Uses `comfy.model_management.unet_offload_device()` if available in a ComfyUI
    environment, otherwise defaults to the CPU.

    :return: The torch.device object to use for offloading.
    :rtype: torch.device
    """
    if with_comfy:
        return cast(torch.device, mm.unet_offload_device())
    return torch.device("cpu")


def get_canonical_device(device: str | torch.device) -> torch.device:
    """
    Converts a device string or object into a canonical torch.device object.

    Ensures that a device string like 'cuda' is converted to its fully
    indexed form, e.g., 'cuda:0', by checking the current default device.

    :param device: The device identifier to canonicalize.
    :type device: str | torch.device
    :return: A torch.device object with an explicit index if applicable.
    :rtype: torch.device
    """
    if not isinstance(device, torch.device):
        device = torch.device(device)

    # If it's a CUDA device and no index is specified, get the default one.
    if device.type == 'cuda' and device.index is None:
        # NOTE: This adds a dependency on torch.cuda.current_device()
        return torch.device(f'cuda:{torch.cuda.current_device()}')
    return device


# ##################################################################################
# # Helper for inference (Target device, offload, eval, no_grad and cuDNN Benchmark)
# ##################################################################################

@contextlib.contextmanager
def model_to_target(logger: logging.Logger, model: torch.nn.Module) -> Iterator[None]:
    """
    A context manager for safe model inference.

    Handles device placement, inference state (`eval()`, `no_grad()`),
    optional cuDNN benchmark settings, and safe offloading in a `finally` block
    to ensure resources are managed even if errors occur.

    The model object is expected to have two optional custom attributes:

    - ``target_device`` (torch.device): The device to run inference on.
    - ``cudnn_benchmark_setting`` (bool): The desired cuDNN benchmark state.

    Example Usage::

        with model_to_target(logger, my_model):
            # Your inference code here
            output = my_model(input_tensor)

    :param logger: A logger instance for logging state changes.
    :type logger: logging.Logger
    :param model: The torch.nn.Module to manage.
    :type model: torch.nn.Module
    :yield: None
    """
    if not isinstance(model, torch.nn.Module):
        with torch.no_grad():
            yield  # The code inside the 'with' statement runs here
        return

    # --- EXPLICIT TYPE ANNOTATIONS ---
    # Tell mypy the expected types for these variables.
    target_device: torch.device
    original_device: torch.device
    original_cudnn_benchmark_state: Optional[bool] = None  # Default is to keep the current setting

    # 1. Determine target device from the model object
    #    Use a hasattr check for robustness, and type hinting
    if hasattr(model, 'target_device') and isinstance(model.target_device, torch.device):
        # Mypy now understands that model.target_device exists and is a torch.device
        target_device = model.target_device
    else:
        logger.warning("model_to_target: 'target_device' attribute not found or is not a torch.device on the model. "
                       "Defaulting to model's current device.")
        # Ensure model has parameters before calling next()
        try:
            target_device = next(model.parameters()).device
        except StopIteration:
            logger.warning("model_to_target: Model has no parameters. Cannot determine device. Assuming CPU.")
            target_device = torch.device("cpu")

    # 2. Get cuDNN benchmark setting from the model object (optional)
    #    Use hasattr as this is an optional setting that not all models might have.
    cudnn_benchmark_enabled: Optional[bool] = None
    if hasattr(model, 'cudnn_benchmark_setting'):
        setting = model.cudnn_benchmark_setting
        if isinstance(setting, bool):
            cudnn_benchmark_enabled = setting
        else:
            logger.warning(f"model.cudnn_benchmark_setting was not a bool, but {type(setting)}. Ignoring.")

    try:
        # Get original_device from model after ensuring it has parameters
        original_device = next(model.parameters()).device
    except StopIteration:
        original_device = torch.device("cpu")  # Match fallback from above

    is_cuda_target = target_device.type == 'cuda'

    try:
        # 3. Manage cuDNN benchmark state
        if (cudnn_benchmark_enabled is not None and is_cuda_target and hasattr(torch.backends, 'cudnn') and
                torch.backends.cudnn.is_available()):
            if torch.backends.cudnn.benchmark != cudnn_benchmark_enabled:
                original_cudnn_benchmark_state = torch.backends.cudnn.benchmark
                torch.backends.cudnn.benchmark = cudnn_benchmark_enabled
                logger.debug(f"Temporarily set cuDNN benchmark to {torch.backends.cudnn.benchmark}")

        # 4. Move model to target device if not already there
        if original_device != target_device:
            logger.debug(f"Moving model from `{original_device}` to target device `{target_device}` for inference.")
            model.to(target_device)

        # 5. Set to eval mode and disable gradients for the operation
        model.eval()
        with torch.no_grad():
            yield  # The code inside the 'with' statement runs here

    finally:
        # 6. Restore original cuDNN benchmark state
        if original_cudnn_benchmark_state is not None:
            # This check is sufficient because it will only be not None if we set it inside the try block
            torch.backends.cudnn.benchmark = original_cudnn_benchmark_state
            logger.debug(f"Restored cuDNN benchmark to {original_cudnn_benchmark_state}")

        # 7. Offload model back to the offload device
        if with_comfy:
            offload_device = get_offload_device()
            try:
                current_device_after_yield = next(model.parameters()).device
                if current_device_after_yield != offload_device:
                    logger.debug(f"Offloading model from `{current_device_after_yield}` to offload device "
                                 f"`{offload_device}`.")
                    model.to(offload_device)
                    # Clear cache if we were on a CUDA device
                    if 'cuda' in str(current_device_after_yield):
                        torch.cuda.empty_cache()
            except StopIteration:
                pass  # Model with no parameters doesn't need offloading.
def get_torch_device_options() -> tuple[list[str], str]:
Detects available PyTorch devices and returns a list and a suitable default.
Scans for CPU, CUDA devices, and MPS (for Apple Silicon), providing a list of device strings (e.g., 'cpu', 'cuda', 'cuda:0') and a recommended default.
Returns
A tuple containing the list of available device strings and the recommended default device string.
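For example, the returned pair can drive a device selector in a UI. A minimal sketch (the printed values depend on the machine, and the widget wiring in the final comment is hypothetical):

    from seconohe.torch import get_torch_device_options

    options, default = get_torch_device_options()
    print(options)  # e.g. ['cpu', 'cuda', 'cuda:0']
    print(default)  # e.g. 'cuda'
    # Typical use: feed both values to a device-selection widget, for
    # instance a ComfyUI node input: "device": (options, {"default": default})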
def get_offload_device() -> torch.device:
Gets the appropriate device for offloading models.
Uses comfy.model_management.unet_offload_device() if available in a ComfyUI environment, otherwise defaults to the CPU.
Returns
The torch.device object to use for offloading.
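A minimal usage sketch (the torch.nn.Linear stand-in is illustrative):

    import torch
    from seconohe.torch import get_offload_device

    model = torch.nn.Linear(8, 8)  # stand-in for a real model
    # Park the model on the offload device while idle; under ComfyUI this is
    # comfy.model_management.unet_offload_device(), otherwise the CPU.
    model.to(get_offload_device())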
def get_canonical_device(device: str | torch.device) -> torch.device:
Converts a device string or object into a canonical torch.device object.
Ensures that a device string like 'cuda' is converted to its fully indexed form, e.g., 'cuda:0', by checking the current default device.
Parameters
- device: The device identifier to canonicalize.
Returns
A torch.device object with an explicit index if applicable.
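Canonicalization matters when comparing devices, because torch.device('cuda') and torch.device('cuda:0') do not compare equal. A minimal sketch (the CUDA lines assume a CUDA-capable machine):

    import torch
    from seconohe.torch import get_canonical_device

    print(get_canonical_device('cpu'))  # -> cpu (returned unchanged)
    # On a CUDA machine, a bare 'cuda' gains the current default index:
    a = get_canonical_device('cuda')                  # e.g. device(type='cuda', index=0)
    b = get_canonical_device(torch.device('cuda:0'))
    assert a == b  # canonical forms compare equal; 'cuda' vs 'cuda:0' would not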
@contextlib.contextmanager
def model_to_target(logger: logging.Logger, model: torch.nn.Module) -> Iterator[None]:
A context manager for safe model inference.
Handles device placement, inference state (eval(), no_grad()), optional cuDNN benchmark settings, and safe offloading in a finally block to ensure resources are managed even if errors occur.
The model object is expected to have two optional custom attributes:

- target_device (torch.device): The device to run inference on.
- cudnn_benchmark_setting (bool): The desired cuDNN benchmark state.
Example Usage:
    with model_to_target(logger, my_model):
        # Your inference code here
        output = my_model(input_tensor)
Parameters
- logger: A logger instance for logging state changes.
- model: The torch.nn.Module to manage.

Yields
None
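Putting it together, a self-contained sketch of the expected workflow (the toy torch.nn.Linear and the attribute values are illustrative; both attributes are optional, and the manager falls back to the model's current device when they are absent):

    import logging
    import torch
    from seconohe.torch import model_to_target

    logger = logging.getLogger(__name__)
    model = torch.nn.Linear(16, 4)  # stand-in for a real inference model

    # The two optional attributes the context manager looks for:
    model.target_device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model.cudnn_benchmark_setting = True  # only consulted for CUDA targets

    x = torch.randn(1, 16)
    with model_to_target(logger, model):
        # Inside: model is on target_device, in eval mode, gradients disabled.
        output = model(x.to(model.target_device))
    # On exit the cuDNN benchmark flag is restored and, under ComfyUI, the
    # model is moved back to the offload device.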