seconohe.torch

# -*- coding: utf-8 -*-
# Copyright (c) 2025 Salvador E. Tropea
# Copyright (c) 2025 Instituto Nacional de Tecnología Industrial
# License: GPL-3.0
# Project: SeCoNoHe
# PyTorch helpers
import contextlib  # For the model_to_target context manager
import logging
import torch
from typing import Optional, Iterator, cast, Union
# ComfyUI's model management is optional: degrade gracefully outside ComfyUI
try:
    import comfy.model_management as mm
    with_comfy = True
except Exception:
    with_comfy = False
from .misc import format_bytes
from .logger import get_debug_level


def get_torch_device_options(with_auto: bool = False) -> tuple[list[str], str]:
    """
    Detects available PyTorch devices and returns a list and a suitable default.

    Scans for CPU, CUDA devices, and MPS (for Apple Silicon), providing a list
    of device strings (e.g., 'cpu', 'cuda', 'cuda:0') and a recommended default.

    :param with_auto: If True, the first option is 'AUTO', meaning the device
                      currently used by ComfyUI.
    :type with_auto: bool (optional)

    :return: A tuple containing the list of available device strings and the
             recommended default device string.
    :rtype: tuple[list[str], str]
    """
    # We always have CPU
    default = "cpu"
    options = [default]
    # Do we have CUDA?
    if torch.cuda.is_available():
        default = "cuda"
        options.append(default)
        for i in range(torch.cuda.device_count()):
            options.append(f"cuda:{i}")  # Specific CUDA devices
    # Is this a Mac?
    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
        options.append("mps")
        if default == "cpu":
            default = "mps"
    # AUTO means: the one used by ComfyUI
    if with_auto:
        options.insert(0, "AUTO")
    return options, default


def get_default_comfy_device() -> torch.device:
    """
    Gets the device ComfyUI is currently using, or the default CUDA device outside ComfyUI.

    :return: The canonical torch.device object.
    :rtype: torch.device
    """
    if with_comfy:
        return get_canonical_device(mm.get_torch_device())
    # torch.cuda.current_device() returns an int index, so build a proper device from it
    return get_canonical_device(torch.device('cuda', torch.cuda.current_device()))


def get_offload_device() -> torch.device:
    """
    Gets the appropriate device for offloading models.

    Uses `comfy.model_management.unet_offload_device()` if available in a ComfyUI
    environment, otherwise defaults to the CPU.

    :return: The torch.device object to use for offloading.
    :rtype: torch.device
    """
    if with_comfy:
        return cast(torch.device, mm.unet_offload_device())
    return torch.device("cpu")


def get_canonical_device(device: str | torch.device) -> torch.device:
    """
    Converts a device string or object into a canonical torch.device object.

    Ensures that a device string like 'cuda' is converted to its fully
    indexed form, e.g., 'cuda:0', by checking the current default device.

    :param device: The device identifier to canonicalize.
    :type device: str | torch.device
    :return: A torch.device object with an explicit index if applicable.
    :rtype: torch.device
    """
    if not isinstance(device, torch.device):
        device = torch.device(device)

    # If it's a CUDA device and no index is specified, resolve it to the
    # current default CUDA device so device comparisons are reliable.
    if device.type == 'cuda' and device.index is None:
        return torch.device(f'cuda:{torch.cuda.current_device()}')
    return device


def get_pytorch_memory_usage_str(device: Optional[Union[int, torch.device]] = None) -> str:
    """
    Returns a formatted string with PyTorch's current and peak memory usage on a CUDA device.

    :param device: The CUDA device to query. If None, the default ComfyUI/CUDA device is used.
    :type device: int | torch.device (optional)
    :return: A formatted string with current/reserved and peak memory figures.
    :rtype: str
    """
    if not torch.cuda.is_available():
        return "CUDA is not available. No GPU memory to report."

    # Resolve the device
    if device is None:
        device = get_default_comfy_device()

    # Get memory stats
    allocated = torch.cuda.memory_allocated(device)
    reserved = torch.cuda.memory_reserved(device)
    peak_allocated = torch.cuda.max_memory_allocated(device)
    peak_reserved = torch.cuda.max_memory_reserved(device)

    # Format into a string
    return (
        f"Allocated/Reserved: {format_bytes(allocated):>10s}/{format_bytes(reserved):>10s}"
        f" (Peak: {format_bytes(peak_allocated):>10s}/{format_bytes(peak_reserved):>10s})\n"
    )


# ##################################################################################
# # Helper for inference (Target device, offload, eval, no_grad and cuDNN Benchmark)
# ##################################################################################

@contextlib.contextmanager
def model_to_target(logger: logging.Logger, model: torch.nn.Module) -> Iterator[None]:
    """
    A context manager for safe model inference.

    Handles device placement, inference state (`eval()`, `no_grad()`),
    optional cuDNN benchmark settings, and safe offloading in a `finally` block
    to ensure resources are managed even if errors occur.

    The model object may have two optional custom attributes:
    - ``target_device`` (torch.device): The device to run inference on.
    - ``cudnn_benchmark_setting`` (bool): The desired cuDNN benchmark state.

    Example Usage::

        with model_to_target(logger, my_model):
            # Your inference code here
            output = my_model(input_tensor)

    :param logger: A logger instance for logging state changes.
    :type logger: logging.Logger
    :param model: The torch.nn.Module to manage.
    :type model: torch.nn.Module
    :yield: None
    """
    if not isinstance(model, torch.nn.Module):
        # Not a torch module: just disable gradients for the duration of the block
        with torch.no_grad():
            yield  # The code inside the 'with' statement runs here
        return

    # Explicit type annotations for the variables mypy can't infer on its own
    target_device: torch.device
    original_device: torch.device
    original_cudnn_benchmark_state: Optional[bool] = None  # Default is to keep the current setting

    # 1. Determine target device from the model object
    # Use a hasattr check for robustness
    if hasattr(model, 'target_device') and isinstance(model.target_device, torch.device):
        target_device = model.target_device
    else:
        logger.warning("model_to_target: 'target_device' attribute not found or is not a torch.device on the model. "
                       "Defaulting to model's current device.")
        # Ensure model has parameters before calling next()
        try:
            target_device = next(model.parameters()).device
        except StopIteration:
            logger.warning("model_to_target: Model has no parameters. Cannot determine device. Assuming CPU.")
            target_device = torch.device("cpu")

    # 2. Get the cuDNN benchmark setting from the model object (optional)
    # Use hasattr as this is an optional setting that not all models might have.
    cudnn_benchmark_enabled: Optional[bool] = None
    if hasattr(model, 'cudnn_benchmark_setting'):
        setting = model.cudnn_benchmark_setting
        if isinstance(setting, bool):
            cudnn_benchmark_enabled = setting
        else:
            logger.warning(f"model.cudnn_benchmark_setting was not a bool, but {type(setting)}. Ignoring.")

    try:
        # Get original_device from the model, matching the fallback used above
        original_device = next(model.parameters()).device
    except StopIteration:
        original_device = torch.device("cpu")

    is_cuda_target = target_device.type == 'cuda'

    try:
        # 3. Manage cuDNN benchmark state
        if (cudnn_benchmark_enabled is not None and is_cuda_target and hasattr(torch.backends, 'cudnn') and
           torch.backends.cudnn.is_available()):
            if torch.backends.cudnn.benchmark != cudnn_benchmark_enabled:
                original_cudnn_benchmark_state = torch.backends.cudnn.benchmark
                torch.backends.cudnn.benchmark = cudnn_benchmark_enabled
                logger.debug(f"Temporarily set cuDNN benchmark to {torch.backends.cudnn.benchmark}")

        # 4. Move model to target device if not already there
        if original_device != target_device:
            logger.debug(f"Moving model from `{original_device}` to target device `{target_device}` for inference.")
            model.to(target_device)

        # 5. Set to eval mode and disable gradients for the operation
        model.eval()
        with torch.no_grad():
            yield  # The code inside the 'with' statement runs here

    finally:
        # 6. Restore the original cuDNN benchmark state
        if original_cudnn_benchmark_state is not None:
            # Only not None if we changed it inside the try block
            torch.backends.cudnn.benchmark = original_cudnn_benchmark_state
            logger.debug(f"Restored cuDNN benchmark to {original_cudnn_benchmark_state}")

        # 7. Offload the model back to the offload device (usually CPU)
        if with_comfy:
            offload_device = get_offload_device()
            try:
                current_device_after_yield = next(model.parameters()).device
                if current_device_after_yield != offload_device:
                    logger.debug(f"Offloading model from `{current_device_after_yield}` to offload device `{offload_device}`.")
                    model.to(offload_device)
                    # Clear the cache if we were on a CUDA device
                    if current_device_after_yield.type == 'cuda':
                        torch.cuda.empty_cache()
            except StopIteration:
                pass  # A model with no parameters doesn't need offloading.


# ##################################################################################
# # Helper for profiling
# ##################################################################################


class TorchProfile(object):
    def __init__(self, logger: logging.Logger, level: int, msg: str, device: Optional[torch.device] = None):
        """
        Measures the time and VRAM used by a CUDA task.

        :param logger: The logger instance used to display the information.
        :type logger: logging.Logger
        :param level: The verbosity level required to display this information.
        :type level: int
        :param msg: The message to display along with the information.
        :type msg: str
        :param device: The CUDA device to profile. If omitted, the default device is used.
        :type device: torch.device (optional)
        """
        super().__init__()
        self.enabled = False
        if get_debug_level(logger) < level:
            return
        if not torch.cuda.is_available():
            logger.debug("CUDA is not available. No GPU memory/timing to report.")
            return
        if device is None:
            device = get_default_comfy_device()
        self.enabled = True
        self.logger = logger
        self.device = device
        self.msg = msg
        self.start_event = torch.cuda.Event(enable_timing=True)
        self.end_event = torch.cuda.Event(enable_timing=True)
        self.start()

    def start(self):
        """ Start profiling. Automatically invoked by the constructor. """
        if not self.enabled:
            return
        self.logger.debug(f"Starting {self.msg} on {self.device}")
        torch.cuda.reset_peak_memory_stats(self.device)
        self.base_vram = torch.cuda.memory_allocated(self.device) / (1024 ** 2)  # in MiB
        self.start_event.record()

    def end(self):
        """ Stop profiling and show the results. """
        if not self.enabled:
            return
        # Stop the timer
        self.end_event.record()
        torch.cuda.synchronize(self.device)
        time = self.start_event.elapsed_time(self.end_event)
        # Get the peak memory used
        mem_peak = torch.cuda.max_memory_allocated(self.device) / (1024 ** 2)  # in MiB
        # Show it
        self.logger.debug(f"Finished {self.msg} on {self.device}: {time:.6f} ms, peak {mem_peak:.2f} MiB"
                          f" (+{mem_peak-self.base_vram:.2f} MiB)")
def get_torch_device_options(with_auto: bool = False) -> tuple[list[str], str]:

Detects available PyTorch devices and returns a list and a suitable default.

Scans for CPU, CUDA devices, and MPS (for Apple Silicon), providing a list of device strings (e.g., 'cpu', 'cuda', 'cuda:0') and a recommended default.

Parameters
  • with_auto: If True, the first option is 'AUTO', meaning the device currently used by ComfyUI.
Returns

A tuple containing the list of available device strings and the recommended default device string.
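
As an illustrative sketch, the result can feed a device dropdown directly (the widget name below is hypothetical, not part of the library):

    from seconohe.torch import get_torch_device_options

    options, default = get_torch_device_options(with_auto=True)
    # e.g. options == ['AUTO', 'cpu', 'cuda', 'cuda:0'], default == 'cuda'
    # Typical ComfyUI INPUT_TYPES entry:
    # "device": (options, {"default": default})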

def get_default_comfy_device() -> torch.device:

Gets the device ComfyUI is currently using, or the default CUDA device outside ComfyUI.
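
A minimal usage sketch, assuming either a ComfyUI environment or a CUDA-capable one:

    from seconohe.torch import get_default_comfy_device, get_canonical_device

    dev = get_default_comfy_device()
    # The result is canonical, so it compares reliably against indexed devices
    print(dev == get_canonical_device('cuda'))  # True when running on the default GPU
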
def get_offload_device() -> torch.device:

Gets the appropriate device for offloading models.

Uses comfy.model_management.unet_offload_device() if available in a ComfyUI environment, otherwise defaults to the CPU.

Returns

The torch.device object to use for offloading.
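
A small sketch of manual offloading with this helper (the Linear module is just a stand-in for a real model):

    import torch
    from seconohe.torch import get_offload_device

    model = torch.nn.Linear(4, 4)
    model.to(get_offload_device())  # CPU outside ComfyUI, ComfyUI's pick inside it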

def get_canonical_device(device: str | torch.device) -> torch.device:

Converts a device string or object into a canonical torch.device object.

Ensures that a device string like 'cuda' is converted to its fully indexed form, e.g., 'cuda:0', by checking the current default device.

Parameters
  • device: The device identifier to canonicalize.
Returns

A torch.device object with an explicit index if applicable.
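
Why this matters: raw torch.device objects compare by (type, index), so an unindexed 'cuda' never equals 'cuda:0'. A quick sketch, assuming device 0 is the current CUDA device:

    import torch
    from seconohe.torch import get_canonical_device

    print(torch.device('cuda') == torch.device('cuda:0'))          # False
    print(get_canonical_device('cuda') == torch.device('cuda:0'))  # True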

def get_pytorch_memory_usage_str(device: Optional[Union[int, torch.device]] = None) -> str:

Returns a formatted string detailing PyTorch's current and peak memory usage on a given CUDA device.

Parameters
  • device: The CUDA device to query. If None, the default ComfyUI/CUDA device is used.
Returns

A formatted string with current/reserved and peak memory figures.
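
A usage sketch; the exact figures and widths come from format_bytes, so the output below is only illustrative:

    import logging
    from seconohe.torch import get_pytorch_memory_usage_str

    logger = logging.getLogger(__name__)
    logger.debug("VRAM: %s", get_pytorch_memory_usage_str())
    # VRAM: Allocated/Reserved:   1.20 GiB/  1.50 GiB (Peak:   2.10 GiB/  2.30 GiB)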

@contextlib.contextmanager
def model_to_target(logger: logging.Logger, model: torch.nn.Module) -> Iterator[None]:

A context manager for safe model inference.

Handles device placement, inference state (eval(), no_grad()), optional cuDNN benchmark settings, and safe offloading in a finally block to ensure resources are managed even if errors occur.

The model object may have two optional custom attributes:

  • target_device (torch.device): The device to run inference on.
  • cudnn_benchmark_setting (bool): The desired cuDNN benchmark state.

Example Usage::

with model_to_target(logger, my_model):
    # Your inference code here
    output = my_model(input_tensor)

Parameters
  • logger: A logger instance for logging state changes.
  • model: The torch.nn.Module to manage.
Yields

None
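
The docstring example assumes the optional attributes are already attached; a fuller sketch with a stand-in model looks like this:

    import logging
    import torch
    from seconohe.torch import model_to_target

    logger = logging.getLogger(__name__)
    model = torch.nn.Linear(16, 16)              # stand-in for a real model
    model.target_device = torch.device('cuda')   # where inference should run
    model.cudnn_benchmark_setting = True         # optional cuDNN benchmark state

    with model_to_target(logger, model):
        # Here the model is on 'cuda', in eval mode, with gradients disabled
        output = model(torch.randn(1, 16, device=model.target_device))
    # On exit the cuDNN state is restored and, under ComfyUI, the model is offloaded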
class TorchProfile:
TorchProfile(logger: logging.Logger, level: int, msg: str, device: Optional[torch.device] = None)

Measures the time and VRAM used by a CUDA task.

Parameters
  • logger: The logger instance used to display the information.
  • level: The verbosity level required to display this information.
  • msg: The message to display along with the information.
  • device: The CUDA device to profile. If omitted, the default device is used.

Instance attributes: enabled, logger, device, msg, start_event, end_event
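
A sketch of a typical profiling cycle; whether anything is logged depends on the logger's debug level as reported by get_debug_level, and the log line shown is only an example:

    import logging
    import torch
    from seconohe.torch import TorchProfile

    logger = logging.getLogger(__name__)
    x = torch.randn(1024, 1024, device='cuda')

    prof = TorchProfile(logger, level=2, msg="matmul")  # start() runs in the constructor
    y = x @ x
    prof.end()  # e.g. "Finished matmul on cuda:0: 0.354208 ms, peak 12.00 MiB (+4.00 MiB)"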
def start(self):

Start profiling. Automatically invoked by the constructor.

def end(self):

Stop profiling and show the results.