seconohe.torch
# -*- coding: utf-8 -*-
# Copyright (c) 2025 Salvador E. Tropea
# Copyright (c) 2025 Instituto Nacional de Tecnología Industrial
# License: GPL-3.0
# Project: SeCoNoHe
# PyTorch helpers
import contextlib  # For context manager
import logging
import torch
from typing import Optional, Iterator, cast, Union
try:
    import comfy.model_management as mm
    with_comfy = True
except Exception:
    with_comfy = False
from .misc import format_bytes
from .logger import get_debug_level


def get_torch_device_options(with_auto: Optional[bool] = False) -> tuple[list[str], str]:
    """
    Detects available PyTorch devices and returns a list and a suitable default.

    Scans for CPU, CUDA devices, and MPS (for Apple Silicon), providing a list
    of device strings (e.g., 'cpu', 'cuda', 'cuda:0') and a recommended default.

    :param with_auto: If True, insert 'AUTO' as the first option, meaning the device currently used by ComfyUI
    :type with_auto: bool (optional)

    :return: A tuple containing the list of available device strings and the
        recommended default device string.
    :rtype: tuple[list[str], str]
    """
    # We always have CPU
    default = "cpu"
    options = [default]
    # Do we have CUDA?
    if torch.cuda.is_available():
        default = "cuda"
        options.append(default)
        for i in range(torch.cuda.device_count()):
            options.append(f"cuda:{i}")  # Specific CUDA devices
    # Is this a Mac?
    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
        options.append("mps")
        if default == "cpu":
            default = "mps"
    # AUTO means: the one used by ComfyUI
    if with_auto:
        options.insert(0, "AUTO")
    return options, default


def get_default_comfy_device() -> torch.device:
    return get_canonical_device(mm.get_torch_device() if with_comfy else torch.cuda.current_device())


def get_offload_device() -> torch.device:
    """
    Gets the appropriate device for offloading models.

    Uses `comfy.model_management.unet_offload_device()` if available in a ComfyUI
    environment, otherwise defaults to the CPU.

    :return: The torch.device object to use for offloading.
    :rtype: torch.device
    """
    if with_comfy:
        return cast(torch.device, mm.unet_offload_device())
    return torch.device("cpu")


def get_canonical_device(device: str | torch.device) -> torch.device:
    """
    Converts a device string or object into a canonical torch.device object.

    Ensures that a device string like 'cuda' is converted to its fully
    indexed form, e.g., 'cuda:0', by checking the current default device.

    :param device: The device identifier to canonicalize.
    :type device: str | torch.device
    :return: A torch.device object with an explicit index if applicable.
    :rtype: torch.device
    """
    if not isinstance(device, torch.device):
        device = torch.device(device)

    # If it's a CUDA device and no index is specified, get the default one.
    if device.type == 'cuda' and device.index is None:
        # NOTE: This adds a dependency on torch.cuda.current_device()
        return torch.device(f'cuda:{torch.cuda.current_device()}')
    return device


def get_pytorch_memory_usage_str(device: Optional[Union[int, torch.device]] = None) -> str:
    """
    Returns a formatted string detailing PyTorch's current and peak memory usage on a given CUDA device.

    Args:
        device (Optional[Union[int, torch.device]]): The CUDA device to query.
            If None, the default ComfyUI device is used.

    Returns:
        str: A formatted string with memory usage details.
    """
    if not torch.cuda.is_available():
        return "CUDA is not available. No GPU memory to report."

    # Resolve the device
    if device is None:
        device = get_default_comfy_device()

    # Get memory stats
    allocated = torch.cuda.memory_allocated(device)
    reserved = torch.cuda.memory_reserved(device)
    peak_allocated = torch.cuda.max_memory_allocated(device)
    peak_reserved = torch.cuda.max_memory_reserved(device)

    # Format into a string
    return (
        f"Allocated/Reserved: {format_bytes(allocated):>10s}/{format_bytes(reserved):>10s}"
        f" (Peak: {format_bytes(peak_allocated):>10s}/{format_bytes(peak_reserved):>10s})\n"
    )


# ##################################################################################
# # Helper for inference (Target device, offload, eval, no_grad and cuDNN Benchmark)
# ##################################################################################

@contextlib.contextmanager
def model_to_target(logger: logging.Logger, model: torch.nn.Module) -> Iterator[None]:
    """
    A context manager for safe model inference.

    Handles device placement, inference state (`eval()`, `no_grad()`),
    optional cuDNN benchmark settings, and safe offloading in a `finally` block
    to ensure resources are managed even if errors occur.

    The model object is expected to have two optional custom attributes:
    - ``target_device`` (torch.device): The device to run inference on.
    - ``cudnn_benchmark_setting`` (bool): The desired cuDNN benchmark state.

    Example Usage::

        with model_to_target(logger, my_model):
            # Your inference code here
            output = my_model(input_tensor)

    :param logger: A logger instance for logging state changes.
    :type logger: logging.Logger
    :param model: The torch.nn.Module to manage.
    :type model: torch.nn.Module
    :yield: None
    """
    if not isinstance(model, torch.nn.Module):
        with torch.no_grad():
            yield  # The code inside the 'with' statement runs here
        return

    # --- EXPLICIT TYPE ANNOTATIONS ---
    # Tell mypy the expected types for these variables.
    target_device: torch.device
    original_device: torch.device
    original_cudnn_benchmark_state: Optional[bool] = None  # Default is to keep the current setting

    # 1. Determine target device from the model object
    #    Use a hasattr check for robustness, and type hinting
    if hasattr(model, 'target_device') and isinstance(model.target_device, torch.device):
        # Mypy now understands that model.target_device exists and is a torch.device
        target_device = model.target_device
    else:
        logger.warning("model_to_target: 'target_device' attribute not found or is not a torch.device on the model. "
                       "Defaulting to model's current device.")
        # Ensure model has parameters before calling next()
        try:
            target_device = next(model.parameters()).device
        except StopIteration:
            logger.warning("model_to_target: Model has no parameters. Cannot determine device. Assuming CPU.")
            target_device = torch.device("cpu")

    # 2. Get CUDNN benchmark setting from the model object (optional)
    #    Use hasattr as this is an optional setting that not all models might have.
    cudnn_benchmark_enabled: Optional[bool] = None
    if hasattr(model, 'cudnn_benchmark_setting'):
        setting = model.cudnn_benchmark_setting
        if isinstance(setting, bool):
            cudnn_benchmark_enabled = setting
        else:
            logger.warning(f"model.cudnn_benchmark_setting was not a bool, but {type(setting)}. Ignoring.")

    try:
        # Get original_device from model after ensuring it has parameters
        original_device = next(model.parameters()).device
    except StopIteration:
        original_device = torch.device("cpu")  # Match fallback from above

    is_cuda_target = target_device.type == 'cuda'

    try:
        # 3. Manage cuDNN benchmark state
        if (cudnn_benchmark_enabled is not None and is_cuda_target and hasattr(torch.backends, 'cudnn') and
                torch.backends.cudnn.is_available()):
            if torch.backends.cudnn.benchmark != cudnn_benchmark_enabled:
                original_cudnn_benchmark_state = torch.backends.cudnn.benchmark
                torch.backends.cudnn.benchmark = cudnn_benchmark_enabled
                logger.debug(f"Temporarily set cuDNN benchmark to {torch.backends.cudnn.benchmark}")

        # 4. Move model to target device if not already there
        if original_device != target_device:
            logger.debug(f"Moving model from `{original_device}` to target device `{target_device}` for inference.")
            model.to(target_device)

        # 5. Set to eval mode and disable gradients for the operation
        model.eval()
        with torch.no_grad():
            yield  # The code inside the 'with' statement runs here

    finally:
        # 6. Restore original cuDNN benchmark state
        if original_cudnn_benchmark_state is not None:
            # This check is sufficient because it will only be not None if we set it inside the try block
            torch.backends.cudnn.benchmark = original_cudnn_benchmark_state
            logger.debug(f"Restored cuDNN benchmark to {original_cudnn_benchmark_state}")

        # 7. Offload model back to CPU
        if with_comfy:
            offload_device = get_offload_device()
            try:
                current_device_after_yield = next(model.parameters()).device
                if current_device_after_yield != offload_device:
                    logger.debug(f"Offloading model from `{current_device_after_yield}` to offload device `{offload_device}`.")
                    model.to(offload_device)
                    # Clear cache if we were on a CUDA device
                    if 'cuda' in str(current_device_after_yield):
                        torch.cuda.empty_cache()
            except StopIteration:
                pass  # Model with no parameters doesn't need offloading.


# ##################################################################################
# # Helper for profiling
# ##################################################################################


class TorchProfile(object):
    def __init__(self, logger: logging.Logger, level: int, msg: str, device: Optional[torch.device] = None):
        """
        Used to measure the time and VRAM used for a CUDA task.

        :param logger: The logger instance to display the information.
        :type logger: logging.Logger
        :param level: The required verbosity level to display this information.
        :type level: int
        :param msg: The message string to display along with the information.
        :type msg: str
        :param device: The CUDA device to profile. If None, the default ComfyUI device is used.
        :type device: torch.device (optional)
        """
        super().__init__()
        self.enabled = False
        if get_debug_level(logger) < level:
            return
        if not torch.cuda.is_available():
            logger.debug("CUDA is not available. No GPU memory/timing to report.")
            return
        if device is None:
            device = get_default_comfy_device()
        self.enabled = True
        self.logger = logger
        self.device = device
        self.msg = msg
        self.start_event = torch.cuda.Event(enable_timing=True)
        self.end_event = torch.cuda.Event(enable_timing=True)
        self.start()

    def start(self):
        """ Starts profiling. Automatically invoked by the constructor """
        if not self.enabled:
            return
        self.logger.debug(f"Starting {self.msg} on {self.device}")
        torch.cuda.reset_peak_memory_stats(self.device)
        self.base_vram = torch.cuda.memory_allocated(self.device) / (1024 ** 2)  # in MiB
        self.start_event.record()

    def end(self):
        """ Stops profiling and logs the results """
        if not self.enabled:
            return
        # Stop timer
        self.end_event.record()
        torch.cuda.synchronize(self.device)
        time = self.start_event.elapsed_time(self.end_event)
        # Get peak used memory
        mem_peak = torch.cuda.max_memory_allocated(self.device) / (1024 ** 2)  # in MiB
        # Show it
        self.logger.debug(f"Finished {self.msg} on {self.device}: {time:.6f} ms, peak {mem_peak:.2f} MiB"
                          f" (+{mem_peak-self.base_vram:.2f} MiB)")
get_torch_device_options(with_auto: Optional[bool] = False) -> tuple[list[str], str]
Detects available PyTorch devices and returns a list and a suitable default.
Scans for CPU, CUDA devices, and MPS (for Apple Silicon), providing a list of device strings (e.g., 'cpu', 'cuda', 'cuda:0') and a recommended default.
Parameters
- with_auto: If True, insert 'AUTO' as the first option, meaning the device currently used by ComfyUI
Returns
A tuple containing the list of available device strings and the recommended default device string.
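A short usage sketch (the exact lists in the comments depend on the machine's hardware):

    options, default = get_torch_device_options(with_auto=True)
    print(options)  # e.g. ['AUTO', 'cpu', 'cuda', 'cuda:0', 'cuda:1']
    print(default)  # e.g. 'cuda'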
get_offload_device() -> torch.device
Gets the appropriate device for offloading models.
Uses comfy.model_management.unet_offload_device() if available in a ComfyUI environment, otherwise defaults to the CPU.
Returns
The torch.device object to use for offloading.
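For illustration, a caller might park an idle model on this device; model is assumed to be a loaded torch.nn.Module:

    model.to(get_offload_device())  # CPU outside ComfyUI, ComfyUI's UNet offload device inside it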
get_canonical_device(device: str | torch.device) -> torch.device
Converts a device string or object into a canonical torch.device object.
Ensures that a device string like 'cuda' is converted to its fully indexed form, e.g., 'cuda:0', by checking the current default device.
Parameters
- device: The device identifier to canonicalize.
Returns
A torch.device object with an explicit index if applicable.
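A small sketch of the normalization; the index shown assumes CUDA device 0 is the current one:

    import torch

    get_canonical_device('cuda')                  # -> device(type='cuda', index=0)
    get_canonical_device('cpu')                   # -> device(type='cpu'), unchanged
    get_canonical_device(torch.device('cuda:1'))  # -> device(type='cuda', index=1), already explicit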
get_pytorch_memory_usage_str(device: Optional[Union[int, torch.device]] = None) -> str
Returns a formatted string detailing PyTorch's current and peak memory usage on a given CUDA device.
Parameters
- device: The CUDA device to query. If None, the default ComfyUI device is used.
Returns
A formatted string with current and peak memory usage details.
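One way a caller might surface this, assuming a debug-level logger; the logger name is arbitrary and the exact figures come from format_bytes:

    import logging

    logger = logging.getLogger("my_nodes")
    logger.debug(get_pytorch_memory_usage_str())
    # e.g. "Allocated/Reserved:   1.20 GiB/  1.50 GiB (Peak:   2.10 GiB/  2.30 GiB)"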
model_to_target(logger: logging.Logger, model: torch.nn.Module) -> Iterator[None]
A context manager for safe model inference.
Handles device placement, inference state (eval(), no_grad()), optional cuDNN benchmark settings, and safe offloading in a finally block to ensure resources are managed even if errors occur.
The model object is expected to have two optional custom attributes:
- target_device (torch.device): The device to run inference on.
- cudnn_benchmark_setting (bool): The desired cuDNN benchmark state.
Example Usage:

    with model_to_target(logger, my_model):
        # Your inference code here
        output = my_model(input_tensor)
Parameters
- logger: A logger instance for logging state changes.
- model: The torch.nn.Module to manage.
Yields
None
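A slightly fuller sketch showing how a caller can attach the optional attributes before entering the context; my_model, input_tensor, and the device choice are placeholders:

    my_model.target_device = torch.device("cuda:0")  # optional: where to run inference
    my_model.cudnn_benchmark_setting = True          # optional: desired cuDNN benchmark state
    with model_to_target(logger, my_model):
        output = my_model(input_tensor)
    # On exit the cuDNN state is restored and, under ComfyUI, the model is offloaded.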
class TorchProfile(object)
__init__(self, logger: logging.Logger, level: int, msg: str, device: Optional[torch.device] = None)
Used to measure the time and VRAM used for a CUDA task.
Parameters
- logger: The logger instance to display the information.
- level: The required verbosity level to display this information.
- msg: The message string to display along with the information.
- device: The CUDA device to profile. If None, the default ComfyUI device is used.
start(self)
Starts profiling. Automatically invoked by the constructor.
end(self)
Stops profiling and logs the results.
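A typical profiling session might look like this; run_inference and batch are placeholders for the measured CUDA work, and the profile only activates when the logger's debug level is at least the given level (2 here) and CUDA is available:

    prof = TorchProfile(logger, 2, "UNet inference")  # records the start event immediately
    output = run_inference(batch)                     # placeholder for the work being measured
    prof.end()  # logs elapsed time, peak VRAM, and the delta versus the starting allocation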