DDPM (Denoising Diffusion Probabilistic Models)
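Sources for these notes:
 1. Denoising Diffusion Probabilistic Models
 2. 大白话AI | Image Generation Model DDPM | Diffusion Models | Generative Models | Denoising Diffusion Probabilistic Generative Model
 3. pytorch-stable-diffusion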
Forward Diffusion Process of Diffusion Models

 The concrete procedure for adding noise to an image

 Derive the next sample $x_t$ from the previous sample $x_{t-1}$
 
 After some derivation (see below), we can compute the $t$-th result $x_t$ directly from the first image $x_0$
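The claim that repeated one-step noising collapses into a single closed form can be checked numerically without any images. The sketch below is an illustration, not part of the original code, and its helper names are hypothetical: it rebuilds the same scaled-linear schedule that `DDPMSampler.__init__` uses (beta_start = 0.00085, beta_end = 0.012, 1000 steps), applies the one-step rule $x_t = \sqrt{\alpha_t}\,x_{t-1} + \sqrt{\beta_t}\,\epsilon$ (with $\alpha_t = 1-\beta_t$) repeatedly while tracking only the coefficient on $x_0$ and the accumulated noise variance, and compares them with $\sqrt{\bar{\alpha}_t}$ and $1-\bar{\alpha}_t$, where $\bar{\alpha}_t$ (hat_alpha_t in the code comments) is the cumulative product of the $\alpha$'s.

```python
import math

def scaled_linear_betas(n=1000, beta_start=0.00085, beta_end=0.012):
    # Same schedule as DDPMSampler.__init__: linear in sqrt(beta), then squared
    lo, hi = beta_start ** 0.5, beta_end ** 0.5
    return [(lo + (hi - lo) * i / (n - 1)) ** 2 for i in range(n)]

def iterate_one_step(betas, t):
    """Apply x_t = sqrt(alpha_t) * x_{t-1} + sqrt(beta_t) * eps for t steps.

    Returns (coefficient on x_0, variance of the accumulated noise)."""
    coef, var = 1.0, 0.0
    for beta in betas[:t]:
        alpha = 1.0 - beta
        coef = math.sqrt(alpha) * coef   # x_0 keeps getting scaled down
        var = alpha * var + beta         # old noise is scaled, new noise is added
    return coef, var

def closed_form(betas, t):
    """Closed form: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps."""
    alpha_bar = math.prod(1.0 - b for b in betas[:t])
    return math.sqrt(alpha_bar), 1.0 - alpha_bar

BETAS = scaled_linear_betas()
for t in (1, 10, 500, 1000):
    print(t, iterate_one_step(BETAS, t), closed_form(BETAS, t))
```

Because the sum of independent Gaussian noises is again Gaussian, only these two numbers matter, which is why `add_noise()` can jump straight to any timestep.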
Main tasks of the DDPM code:
 (1) Add noise to the clear image $x_0$
 (2) Calculate $\tilde{\mu}_t$ (mean) and $\tilde{\beta}_t$ (variance) of the distribution $q(x_{t-1} \mid x_t, x_0) = \mathcal{N}(x_{t-1}; \tilde{\mu}_t, \tilde{\beta}_t I)$
 (3) Update $\tilde{\mu}_t$ (mean)
 
(1) Add noise to the clear image with the function add_noise()
 
 The derivation of the noising formula in the figure above is shown in the figure below.
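One common way to carry out this derivation is the reparameterization argument sketched below (it corresponds to equation (4) of the DDPM paper). It uses $\alpha_t = 1-\beta_t$, $\bar{\alpha}_t = \prod_{s=1}^{t}\alpha_s$ (hat_alpha_t in the code comments), and the fact that the sum of two independent zero-mean Gaussians is a zero-mean Gaussian whose variance is the sum of the two variances; $\epsilon_{t-1}$, $\epsilon_{t-2}$, $\bar{\epsilon}$ and $\epsilon$ are standard normal noise terms.

```latex
\begin{aligned}
x_t &= \sqrt{\alpha_t}\,x_{t-1} + \sqrt{1-\alpha_t}\,\epsilon_{t-1} \\
    &= \sqrt{\alpha_t}\big(\sqrt{\alpha_{t-1}}\,x_{t-2} + \sqrt{1-\alpha_{t-1}}\,\epsilon_{t-2}\big) + \sqrt{1-\alpha_t}\,\epsilon_{t-1} \\
    &= \sqrt{\alpha_t\alpha_{t-1}}\,x_{t-2}
       + \underbrace{\sqrt{\alpha_t(1-\alpha_{t-1})}\,\epsilon_{t-2} + \sqrt{1-\alpha_t}\,\epsilon_{t-1}}_{\text{variance } \alpha_t(1-\alpha_{t-1}) + (1-\alpha_t) \;=\; 1-\alpha_t\alpha_{t-1}} \\
    &= \sqrt{\alpha_t\alpha_{t-1}}\,x_{t-2} + \sqrt{1-\alpha_t\alpha_{t-1}}\,\bar{\epsilon} \\
    &= \cdots \\
    &= \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)
\end{aligned}
```

Equivalently, $q(x_t \mid x_0) = \mathcal{N}\big(x_t; \sqrt{\bar{\alpha}_t}\,x_0, (1-\bar{\alpha}_t) I\big)$, which is exactly what `add_noise()` samples from.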

 Implementing add_noise(clear image: $x_0$, timesteps: $t$)
class DDPMSampler:
    def __init__(...):
        ...
    def set_inference_timesteps(...):  # Set the number of inference timesteps for the DDPM model.
        ...
    def _get_previous_timestep(...):  # Calculate the previous timestep for the given timestep.
        ...
    def _get_variance(...):  # Calculate the variance for the given timestep.
        ...
    def set_strength(...):  # Set how much noise to add to the input image.
        ...

    def step(...):  # Perform one step of the reverse (denoising) process.
        ...

    def add_noise(  # Add noise to the original samples according to the diffusion (forward) process.
        self,
        original_samples: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """
            Add noise to the original samples according to the diffusion process.
            Args:
            - original_samples (torch.FloatTensor): The original samples (images) to which noise will be added.
            - timesteps (torch.IntTensor): The timesteps at which the noise will be added.
            Returns:
            - torch.FloatTensor: The noisy samples.
        """
        # Retrieve the cumulative product of alphas on the same device and with the same dtype as the original samples
        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        # Move timesteps to the same device as the original samples
        timesteps = timesteps.to(original_samples.device)
        # Compute the square root of the cumulative product of alphas for the given timesteps
        # sqrt{hat_alpha_t}
        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        # Flatten sqrt_alpha_prod to ensure it's a 1D tensor
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        # Append singleton dimensions until sqrt_alpha_prod has as many dimensions as original_samples (for broadcasting)
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
        # Compute the square root of (1 - cumulative product of alphas) for the given timesteps
        # sqrt{1-hat_alpha_t}
        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        # Flatten sqrt_one_minus_alpha_prod to ensure it's a 1D tensor
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        # Append singleton dimensions while sqrt_one_minus_alpha_prod has fewer dimensions than original_samples
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
        # Sample from q(x_t | x_0) as in equation (4) of https://arxiv.org/pdf/2006.11239.pdf
        # A sample X ~ N(mu, sigma^2) can be obtained as X = mu + sigma * N(0, 1);
        # here mu = sqrt_alpha_prod * original_samples and sigma = sqrt_one_minus_alpha_prod
        # Sample noise from a standard normal distribution with the same shape as the original samples
        noise = torch.randn(original_samples.shape, generator=self.generator, device=original_samples.device, dtype=original_samples.dtype)
        # sqrt_alpha_prod * original_samples is the mean component of the noisy sample:
        # it scales the original samples by the square root of the cumulative product of alphas.
        # sqrt_one_minus_alpha_prod * noise is the noise component (standard deviation times N(0, 1)):
        # it scales the random noise by the square root of (1 - cumulative product of alphas).
        # Adding the two gives the noisy samples; how much of the original survives and
        # how much noise is mixed in both depend on the timestep.
        # x_t = sqrt{hat_alpha_t} * x_0 + sqrt{1-hat_alpha_t} * epsilon
        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples
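As a sketch (assuming PyTorch is installed; `make_alphas_cumprod` and `add_noise_sketch` are hypothetical helpers, not part of pytorch-stable-diffusion), the same computation can be written with a single `reshape` instead of the two `while` loops. It also returns the sampled noise, so we can check that $x_0$ is recoverable from $x_t$ and $\epsilon$ by inverting the formula, which is the "predicted $x_0$" step that `step()` relies on. It assumes `timesteps` is a 1D tensor with one entry per sample in the batch.

```python
import torch

def make_alphas_cumprod(n=1000, beta_start=0.00085, beta_end=0.012):
    # Same scaled-linear schedule as DDPMSampler.__init__
    betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, n, dtype=torch.float32) ** 2
    return torch.cumprod(1.0 - betas, dim=0)

def add_noise_sketch(x0, timesteps, alphas_cumprod, generator=None):
    """Return (x_t, noise) with x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * noise."""
    a_bar = alphas_cumprod.to(device=x0.device, dtype=x0.dtype)[timesteps.to(x0.device)]
    # (B,) -> (B, 1, 1, ...): one coefficient per sample, broadcast over the remaining dims
    a_bar = a_bar.reshape(-1, *([1] * (x0.dim() - 1)))
    noise = torch.randn(x0.shape, generator=generator, device=x0.device, dtype=x0.dtype)
    return a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * noise, noise

alphas_cumprod = make_alphas_cumprod()
g = torch.Generator().manual_seed(0)
x0 = torch.rand(2, 4, 8, 8) * 2 - 1   # two fake "latents" with values in [-1, 1]
t = torch.tensor([10, 900])           # a different timestep for each sample
xt, eps = add_noise_sketch(x0, t, alphas_cumprod, generator=g)

# Invert x_t = sqrt(a_bar) * x_0 + sqrt(1 - a_bar) * eps to get x_0 back
a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)
x0_rec = (xt - (1 - a_bar).sqrt() * eps) / a_bar.sqrt()
```

In practice the true `eps` is unknown at sampling time; the U-Net's prediction takes its place, which is why the recovered $x_0$ in `step()` is only an estimate.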
(2) Calculate $\tilde{\mu}_t$ (mean) and $\tilde{\beta}_t$ (variance) of the distribution $q(x_{t-1} \mid x_t, x_0) = \mathcal{N}(x_{t-1}; \tilde{\mu}_t, \tilde{\beta}_t I)$
 Note: $\mathcal{N}(\text{output}; \text{mean}, \text{variance})$
 
 
The derivation of the mean and variance of the distribution above is shown in the figure below.
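One standard way to carry out this derivation (the Bayes-rule argument behind equations (6) and (7) of the DDPM paper) is sketched below per dimension. It uses $\alpha_t = 1-\beta_t$, $\bar{\alpha}_t = \alpha_t\,\bar{\alpha}_{t-1}$, the one-step forward rule for $q(x_t \mid x_{t-1})$ and the closed form for $q(x_{t-1} \mid x_0)$; completing the square in $x_{t-1}$ yields the same two coefficients and the same variance that `step()` and `_get_variance()` compute.

```latex
\begin{aligned}
q(x_{t-1}\mid x_t, x_0) &\propto q(x_t\mid x_{t-1})\,q(x_{t-1}\mid x_0)
  \propto \exp\!\Big(-\tfrac{1}{2}\Big[\tfrac{(x_t-\sqrt{\alpha_t}\,x_{t-1})^2}{\beta_t}
  + \tfrac{(x_{t-1}-\sqrt{\bar{\alpha}_{t-1}}\,x_0)^2}{1-\bar{\alpha}_{t-1}}\Big]\Big) \\[4pt]
\frac{1}{\tilde{\beta}_t} &= \frac{\alpha_t}{\beta_t} + \frac{1}{1-\bar{\alpha}_{t-1}}
  = \frac{1-\bar{\alpha}_t}{\beta_t\,(1-\bar{\alpha}_{t-1})}
  \quad\Longrightarrow\quad
  \tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\,\beta_t \\[4pt]
\tilde{\mu}_t &= \tilde{\beta}_t\Big(\frac{\sqrt{\alpha_t}}{\beta_t}\,x_t + \frac{\sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}}\,x_0\Big)
  = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1-\bar{\alpha}_t}\,x_0
  + \frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\,x_t \\[4pt]
\text{with } x_0 = \frac{x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon}{\sqrt{\bar{\alpha}_t}}:\quad
\tilde{\mu}_t &= \frac{1}{\sqrt{\alpha_t}}\Big(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon\Big)
\end{aligned}
```

The precision step uses $\alpha_t(1-\bar{\alpha}_{t-1}) + \beta_t = 1 - \alpha_t\bar{\alpha}_{t-1} = 1-\bar{\alpha}_t$.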

 Implement _get_variance() to compute the variance, and step() to compute the mean and then update it.
class DDPMSampler:
    def __init__(...):
        ...
    def set_inference_timesteps(...):  # Set the number of inference timesteps for the DDPM model.
        ...
    def _get_previous_timestep(...):  # Calculate the previous timestep for the given timestep.
        ...
    def _get_variance(...):  # Calculate the variance for the given timestep.
        ...
    def set_strength(...):  # Set how much noise to add to the input image.
        ...

    def step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
        """
            Perform one step of the reverse (denoising) process: sample x_{t-1} given x_t.
            Args:
            - timestep (int): The current timestep t.
            - latents (torch.Tensor): The current noisy latents x_t.
            - model_output (torch.Tensor): The noise epsilon(x_t, t) predicted by the diffusion model.
        """
        t = timestep
        # Get the previous timestep using the _get_previous_timestep method
        prev_t = self._get_previous_timestep(t)
        # 1. compute alphas, betas
        # hat_alpha_t
        alpha_prod_t = self.alphas_cumprod[t]
        # hat_alpha_{t-1}
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        # hat_beta_t = 1 - hat_alpha_t
        beta_prod_t = 1 - alpha_prod_t
        # hat_beta_{t-1} = 1 - hat_alpha_{t-1}
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        # alpha_prod_t / alpha_prod_t_prev = (alpha_t*alpha_{t-1}*...*alpha_1) / (alpha_{t-1}*...*alpha_1) = alpha_t
        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
        # beta_t = 1 - alpha_t
        current_beta_t = 1 - current_alpha_t
        # 2. compute the predicted original sample from the predicted noise, also called
        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
        # x_t = sqrt{1 - hat_alpha_t} * epsilon + sqrt{hat_alpha_t} * x_0
        # x_0 = (x_t - sqrt{1 - hat_alpha_t} * epsilon(x_t)) / sqrt{hat_alpha_t}
        # x_0 = (x_t - sqrt{hat_beta_t} * epsilon(x_t)) / sqrt{hat_alpha_t}
        pred_original_sample = (latents - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        # x_{t-1} ~ p_{theta}(x_{t-1} | x_t), the distribution over x_{t-1} in the reverse process
        # = N (1/sqrt{alpha_t} * x_t - beta_t/(sqrt{alpha_t}sqrt{1-hat_alpha_t}) * epsilon(x_t,t)
        #      , beta_t * (1-hat_alpha_{t-1})/(1-hat_alpha_t) )
        # x_{t-1} ~ q(x_{t-1} | x_t,x_0), the posterior over x_{t-1} of the forward process
        # = N (frac{sqrt{hat_alpha_{t-1}}beta_t}{1-hat_alpha_t}x_0+frac{sqrt{alpha_t}(1-hat_alpha_{t-1})}{1-hat_alpha_t}*x_t
        #      , beta_t * (1-hat_alpha_{t-1})/(1-hat_alpha_t) )
        # frac{sqrt{hat_alpha_{t-1}}beta_t}{1-hat_alpha_t}
        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
        # frac{sqrt{alpha_t}(1-hat_alpha_{t-1})}{1-hat_alpha_t}
        current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
        # 5. Compute predicted previous sample µ_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        # pred_mu_t = coeff_1 * x_0 + coeff_2 * x_t
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * latents
        # 6. Update pred_mu_t according to pred_beta_t
        ...

    def add_noise(...):
        ...
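The comments in `step()` quote the mean in two forms: a weighted combination of the predicted $x_0$ and $x_t$ (what the code computes) and $\frac{1}{\sqrt{\alpha_t}}\big(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon\big)$. The plain-Python check below (hypothetical helper names, scalar inputs, same schedule as `DDPMSampler`) suggests they agree even with the 50-step spacing of `set_inference_timesteps(50)`, where the previous timestep is `t - 20`, provided `current_alpha_t` is taken as `hat_alpha_t / hat_alpha_prev`, as `step()` does.

```python
import math

def scaled_linear_alphas_cumprod(n=1000, beta_start=0.00085, beta_end=0.012):
    # Cumulative products hat_alpha_t of the DDPMSampler schedule
    lo, hi = beta_start ** 0.5, beta_end ** 0.5
    out, prod = [], 1.0
    for i in range(n):
        beta = (lo + (hi - lo) * i / (n - 1)) ** 2
        prod *= 1.0 - beta
        out.append(prod)
    return out

def posterior_mean_two_ways(x_t, eps, t, prev_t, alphas_cumprod):
    a_t = alphas_cumprod[t]
    a_prev = alphas_cumprod[prev_t] if prev_t >= 0 else 1.0
    alpha_t = a_t / a_prev      # current_alpha_t in step()
    beta_t = 1.0 - alpha_t      # current_beta_t in step()
    # Form 1 (what step() does): predict x_0, then mix x_0 and x_t
    x0 = (x_t - math.sqrt(1 - a_t) * eps) / math.sqrt(a_t)
    mu1 = (math.sqrt(a_prev) * beta_t / (1 - a_t)) * x0 \
        + (math.sqrt(alpha_t) * (1 - a_prev) / (1 - a_t)) * x_t
    # Form 2: written directly in terms of the noise epsilon
    mu2 = (x_t - beta_t / math.sqrt(1 - a_t) * eps) / math.sqrt(alpha_t)
    return mu1, mu2

ALPHAS_CUMPROD = scaled_linear_alphas_cumprod()
for t, prev_t in [(980, 960), (500, 480), (20, 0), (0, -20)]:
    print(t, posterior_mean_two_ways(0.7, -1.3, t, prev_t, ALPHAS_CUMPROD))
```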
Why do we need to compute the distribution $q(x_{t-1} \mid x_t, x_0)$?
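 A KL-divergence term appears in the derivation of the Stable Diffusion loss function. It compares the forward-process posterior $q(x_{t-1} \mid x_t, x_0)$ with the learned reverse distribution $p_\theta(x_{t-1} \mid x_t)$, i.e. it measures how similar the two distributions are, and minimizing it is what guides the reverse process, step by step, toward generating the final image. A detailed explanation will follow in a later blog post.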
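For context (this comes from the DDPM paper's equation (8), not from the blog's code): when both distributions in the KL term are Gaussians with the same fixed variance, as when the reverse variance is fixed to $\tilde{\beta}_t$, the KL divergence reduces to a scaled squared distance between the two means. That is why it is enough for the network to predict the mean, or equivalently the noise. A per-dimension sketch in plain Python, with a hypothetical helper name:

```python
import math

def gaussian_kl(mu_q, var_q, mu_p, var_p):
    """KL( N(mu_q, var_q) || N(mu_p, var_p) ) for one-dimensional Gaussians."""
    return 0.5 * (math.log(var_p / var_q) + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0)

# Same variance on both sides: the log term vanishes and only the mean gap remains,
# i.e. KL = (mu_q - mu_p) ** 2 / (2 * var)
var = 0.5
print(gaussian_kl(1.0, var, 0.2, var), (1.0 - 0.2) ** 2 / (2 * var))
```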
 

 

 (3) Update $\tilde{\mu}_t$ (mean)

 $x_{t-1} = \tilde{\mu}_t + \sqrt{\tilde{\beta}_t}\,\epsilon \quad (\text{Note: } \epsilon \sim \mathcal{N}(0, I))$

 Since $\tilde{\beta}_t$ is the variance, the noise is scaled by the standard deviation $\sqrt{\tilde{\beta}_t}$, not by $\tilde{\beta}_t$ itself; this is what step() computes with self._get_variance(t) ** 0.5. The "updated" $\tilde{\mu}_t$ is therefore a sample of $x_{t-1}$, and no noise is added at the final step ($t = 0$).
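A minimal PyTorch sketch of this update (`sample_prev` is a hypothetical helper, not a function from the original code): draw $\epsilon \sim \mathcal{N}(0, I)$, scale it by the standard deviation $\sqrt{\tilde{\beta}_t}$, and skip the noise at $t = 0$, mirroring the `if t > 0` branch of `step()`. With many samples, the empirical standard deviation should come out close to $\sqrt{\tilde{\beta}_t}$.

```python
import torch

def sample_prev(mu, variance, t, generator=None):
    """Draw x_{t-1} ~ N(mu, variance * I) via x_{t-1} = mu + sqrt(variance) * eps."""
    if t == 0:
        return mu  # final step: return the mean without adding noise
    eps = torch.randn(mu.shape, generator=generator, dtype=mu.dtype, device=mu.device)
    return mu + variance ** 0.5 * eps

g = torch.Generator().manual_seed(0)
mu = torch.full((200_000,), 0.3)
variance = 0.04  # plays the role of tilde_beta_t
x_prev = sample_prev(mu, variance, t=5, generator=g)
print(x_prev.mean().item(), x_prev.std().item())  # roughly 0.3 and sqrt(0.04) = 0.2
```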
class DDPMSampler:
    def __init__(...):
        ...
    def set_inference_timesteps(...):  # Set the number of inference timesteps for the DDPM model.
        ...
    def _get_previous_timestep(...):  # Calculate the previous timestep for the given timestep.
        ...
    def _get_variance(...):  # Calculate the variance for the given timestep.
        ...
    def set_strength(...):  # Set how much noise to add to the input image.
        ...

    def step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
        """
            Perform one step of the reverse (denoising) process: sample x_{t-1} given x_t.
            Args:
            - timestep (int): The current timestep t.
            - latents (torch.Tensor): The current noisy latents x_t.
            - model_output (torch.Tensor): The noise epsilon(x_t, t) predicted by the diffusion model.
        """
        ...
        # 6. Update pred_mu_t according to pred_beta_t
        variance = 0
        if t > 0:
            # Get the device of model_output
            device = model_output.device
            # Generate random noise with the same shape as model_output
            noise = torch.randn(model_output.shape, generator=self.generator, device=device, dtype=model_output.dtype)
            # Standard deviation times noise, sqrt{pred_beta_t} * epsilon (see formula (7) from https://arxiv.org/pdf/2006.11239.pdf)
            variance = (self._get_variance(t) ** 0.5) * noise
            # Add the scaled noise to the predicted previous sample:
            # a sample from N(mu, sigma^2) can be obtained as X = mu + sigma * N(0, 1),
            # and the variable "variance" already holds sigma * N(0, 1)
            # For t > 0, use the predicted variance pred_beta_t (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
            # and sample from it to get the previous sample
            # pred_mu_t = pred_mu_t + sqrt{pred_beta_t} * epsilon (Note: epsilon ~ N(0,1))
            pred_prev_sample = pred_prev_sample + variance
        return pred_prev_sample

    def add_noise(...):
        ...
Complete DDPM code (ddpm.py)
import torch
import numpy as np
'''
# Forward process (add_noise) and reverse sampling step (step)
# Add noise to the clear image, calculate pred_mu_t and pred_beta_t of the posterior distribution, and update pred_mu_t
# (1) Add noise to the clear image using add_noise()
#  x_t = sqrt{hat_alpha_t} * x_0 + sqrt{1-hat_alpha_t} * epsilon (Note: epsilon ~ N(0,1))
# see formula (4) from https://arxiv.org/pdf/2006.11239.pdf
# (2) calculate pred_mu_t and pred_beta_t of the distribution
# q(x_{t-1}|x_t,x_0) = N(pred_mu_t, pred_beta_t*I)
# step():
# predicted_mu_t = coeff_1 * x_0 + coeff_2 * x_t
# _get_variance():
# predicted variance pred_beta_t = (1-hat_alpha_{t-1})/(1-hat_alpha_t) * beta_t
# (3) update pred_mu_t
# step():
# update pred_mu_t = pred_mu_t + sqrt{pred_beta_t} * noise (Note: noise ~ N(0,1))
# see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf
'''
class DDPMSampler:
    def __init__(self, generator: torch.Generator, num_training_steps=1000, beta_start: float = 0.00085, beta_end: float = 0.0120):
        # Params "beta_start" and "beta_end" taken from:
        # https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/configs/stable-diffusion/v1-inference.yaml#L5C8-L5C8
        # For the naming conventions, refer to the DDPM paper (https://arxiv.org/pdf/2006.11239.pdf)
        """
            Initialize the DDPM (Denoising Diffusion Probabilistic Model) parameters.
            Args:
            - generator (torch.Generator): A PyTorch random number generator.
            - num_training_steps (int, optional): Number of training steps. Default is 1000.
            - beta_start (float, optional): The starting value of beta. Default is 0.00085.
            - beta_end (float, optional): The ending value of beta. Default is 0.0120.
        """
        self.betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_training_steps, dtype=torch.float32) ** 2
        # alpha = 1 - beta
        self.alphas = 1.0 - self.betas
        # hat_alpha_t = alpha_t * alpha_{t-1} * ... * alpha_2 * alpha_1
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # Define a tensor representing the value 1.0
        self.one = torch.tensor(1.0)
        # Store the generator for random number generation
        self.generator = generator
        # Number of training timesteps
        self.num_train_timesteps = num_training_steps
        # Create a tensor of timesteps in reverse order
        self.timesteps = torch.from_numpy(np.arange(0, num_training_steps)[::-1].copy())
    def set_inference_timesteps(self, num_inference_steps=50):
        """
            Set the number of inference timesteps for the DDPM model.
            Args:
            - num_inference_steps (int, optional): Number of steps to use during inference. Default is 50.
        """
        # Store the number of inference steps
        self.num_inference_steps = num_inference_steps
        # Calculate the ratio between training timesteps and inference timesteps
        step_ratio = self.num_train_timesteps // self.num_inference_steps
        # Generate an array of timesteps for inference:
        # - np.arange(0, num_inference_steps): Create an array from 0 to num_inference_steps-1
        # - Multiply by step_ratio to space out the timesteps
        # - round() to ensure the timesteps are integers
        # - [::-1] to reverse the order, as inference typically proceeds backward through the timesteps
        # - copy() to ensure the array is contiguous in memory
        # - astype(np.int64) to ensure the timesteps are of type int64, which is compatible with PyTorch
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
        # Convert the numpy array of timesteps to a PyTorch tensor
        self.timesteps = torch.from_numpy(timesteps)
    def _get_previous_timestep(self, timestep: int) -> int:
        """
            Calculate the previous timestep for the given timestep during inference.
            Args:
            - timestep (int): The current timestep during inference.
            Returns:
            - int: The previous timestep during inference.
        """
        # Calculate the previous timestep by subtracting the step ratio from the current timestep.
        # The step ratio is the integer division of the total number of training timesteps by the number of inference timesteps.
        # timestep t-1 = timestep t - ratio
        prev_t = timestep - self.num_train_timesteps // self.num_inference_steps
        return prev_t
    
    def _get_variance(self, timestep: int) -> torch.Tensor:
        """
            Calculate the variance for the given timestep during inference.
            Args:
            - timestep (int): The current timestep during inference.
            Returns:
            - torch.Tensor: The variance for the given timestep.
        """
        # Get the previous timestep using the _get_previous_timestep method
        prev_t = self._get_previous_timestep(timestep)
        # Retrieve the cumulative product of alphas at the current and previous timesteps
        # hat_alpha_t
        alpha_prod_t = self.alphas_cumprod[timestep]
        # hat_alpha_{t-1}
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        # alpha_prod_t / alpha_prod_t_prev = (alpha_t*alpha_{t-1}*...*alpha_1) / (alpha_{t-1}*...*alpha_1) = alpha_t
        # beta_t = 1- alpha_t
        current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
        # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
        # and sample from it to get previous sample
        # x_{t-1} ~ P(x_{t-1} | x_t,x_0)
        # = N (mu, sigma)
        # = N (1/sqrt{alpha_t} * x_t - beta_t/(sqrt{alpha_t}sqrt{1-hat_alpha_t}) * epsilon
        #      , beta_t * (1-hat_alpha_{t-1})/(1-hat_alpha_t) )
        # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
        # Clamp the variance to a tiny positive value so it is never exactly zero (e.g. in case its log is taken later)
        variance = torch.clamp(variance, min=1e-20)
        return variance
    
    def set_strength(self, strength=1):
        """
            Set how much noise to add to the input image.
            Args:
            - strength (float, optional): A value between 0 and 1 indicating the amount of noise to add.
            - A strength value close to 1 means the output will be further from the input image (more noise).
            - A strength value close to 0 means the output will be closer to the input image (less noise).
        """
        # Calculate the number of inference steps to skip based on the strength
        # Higher strength means fewer steps skipped (more noise added)
        # start_step is the number of noise levels to skip
        start_step = self.num_inference_steps - int(self.num_inference_steps * strength)
        # Update the timesteps to start from the calculated step
        # This effectively sets the starting point for the noise addition process
        self.timesteps = self.timesteps[start_step:]
        # Store the starting step for reference
        self.start_step = start_step
    def step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
        """
            Perform one step of the reverse (denoising) process: sample x_{t-1} given x_t.
            Args:
            - timestep (int): The current timestep t.
            - latents (torch.Tensor): The current noisy latents x_t.
            - model_output (torch.Tensor): The noise epsilon(x_t, t) predicted by the diffusion model.
        """
        t = timestep
        # Get the previous timestep using the _get_previous_timestep method
        prev_t = self._get_previous_timestep(t)
        # 1. compute alphas, betas
        # hat_alpha_t
        alpha_prod_t = self.alphas_cumprod[t]
        # hat_alpha_{t-1}
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        # hat_beta_t = 1 - hat_alpha_t
        beta_prod_t = 1 - alpha_prod_t
        # hat_beta_{t-1} = 1 - hat_alpha_{t-1}
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        # alpha_prod_t / alpha_prod_t_prev = (alpha_t*alpha_{t-1}*...*alpha_1) / (alpha_{t-1}*...*alpha_1) = alpha_t
        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
        # beta_t = 1- alpha_t
        current_beta_t = 1 - current_alpha_t
        # 2. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
        # x_t = sqrt{1 - hat_alpha_t}* epsilon + sqrt{hat_alpha_t} * x_0
        # x_0 = (x_t - sqrt{1 - hat_alpha_t} * epsilon(x_t)) / sqrt{hat_alpha_t}
        # x_0 = (x_t - sqrt{hat_beta_t} * epsilon(x_t)) / sqrt{hat_alpha_t}
        pred_original_sample = (latents - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        # x_{t-1} ~ p_{theta}(x_{t-1} | x_t) a distribution with regard to x_{t-1} during reverse process
        # = N (1/sqrt{alpha_t} * x_t - beta_t/(sqrt{alpha_t}sqrt{1-hat_alpha_t}) * epsilon(x_t,t)
        #      , beta_t * (1-hat_alpha_{t-1})/(1-hat_alpha_t) )
        # x_{t-1} ~ q(x_{t-1} | x_t,x_0), the posterior over x_{t-1} of the forward process
        # = N (frac{sqrt{hat_alpha_{t-1}}beta_t}{1-hat_alpha_t}x_0+frac{sqrt{alpha_t}(1-hat_alpha_{t-1})}{1-hat_alpha_t}*x_t
        #      , beta_t * (1-hat_alpha_{t-1})/(1-hat_alpha_t) )
        # frac{sqrt{hat_alpha_{t-1}}beta_t}{1-hat_alpha_t}
        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
        # frac{sqrt{alpha_t}(1-hat_alpha_{t-1})}{1-hat_alpha_t}
        current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
        # 5. Compute predicted previous sample µ_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        # pred_mu_t = coeff_1 * x_0 + coeff_2 * x_t
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * latents
        # 6. Update pred_mu_t according to pred_beta_t
        variance = 0
        if t > 0:
            # Get the device of model_output
            device = model_output.device
            # Generate random noise with the same shape as model_output
            noise = torch.randn(model_output.shape, generator=self.generator, device=device, dtype=model_output.dtype)
            # Compute the variance for the current timestep as per formula (7) from https://arxiv.org/pdf/2006.11239.pdf
            # Standard deviation times noise: sqrt{pred_beta_t} * epsilon
            variance = (self._get_variance(t) ** 0.5) * noise
            # Add the scaled noise to the predicted previous sample:
            # a sample from N(mu, sigma^2) can be obtained as X = mu + sigma * N(0, 1),
            # and the variable "variance" already holds sigma * N(0, 1)
            # For t > 0, use the predicted variance pred_beta_t (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
            # and sample from it to get the previous sample
            # pred_mu_t = pred_mu_t + sqrt{pred_beta_t} * epsilon (Note: epsilon ~ N(0,1))
            pred_prev_sample = pred_prev_sample + variance
        return pred_prev_sample
    
    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """
            Add noise to the original samples according to the diffusion process.
            Args:
            - original_samples (torch.FloatTensor): The original samples (images) to which noise will be added.
            - timesteps (torch.IntTensor): The timesteps at which the noise will be added.
            Returns:
            - torch.FloatTensor: The noisy samples.
        """
        # Retrieve the cumulative product of alphas on the same device and with the same dtype as the original samples
        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        # Move timesteps to the same device as the original samples
        timesteps = timesteps.to(original_samples.device)
        # Compute the square root of the cumulative product of alphas for the given timesteps
        # sqrt{hat_alpha_t}
        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        # Flatten sqrt_alpha_prod to ensure it's a 1D tensor
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        # Reshape sqrt_alpha_prod to match the dimensions of original_samples
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
        # Compute the square root of (1 - cumulative product of alphas) for the given timesteps
        # sqrt{1-hat_alpha_t}
        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        # Flatten sqrt_one_minus_alpha_prod to ensure it's a 1D tensor
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        # Reshape sqrt_one_minus_alpha_prod to match the dimensions of original_samples
        # keep appending singleton dimensions while sqrt_one_minus_alpha_prod has fewer dimensions than original_samples
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
        # Sample from q(x_t | x_0) as in equation (4) of https://arxiv.org/pdf/2006.11239.pdf
        # Because N(mu, sigma) = X can be obtained by X = mu + sigma * N(0, 1)
        # here mu = sqrt_alpha_prod * original_samples and sigma = sqrt_one_minus_alpha_prod
        # Sample noise from a normal distribution with the same shape as the original samples
        noise = torch.randn(original_samples.shape, generator=self.generator, device=original_samples.device, dtype=original_samples.dtype)
        # sqrt_alpha_prod * original_samples (This represents the mean component in the noisy sample calculation.)
        # This term scales the original samples by the square root of the cumulative product of alphas for the given timesteps.
        # sqrt_one_minus_alpha_prod * noise (This represents the noise component: standard deviation times N(0, 1).)
        # This term scales the random noise by the square root of (1 - cumulative product of alphas) for the given timesteps.
        # sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        # adds the scaled noise to the scaled original samples. This operation forms the noisy samples,
        # where the influence of the original samples and the noise varies according to the timesteps.
        # x_t = sqrt{hat_alpha_t} * x_0 + sqrt{1-hat_alpha_t} * epsilon
        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples
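To see the pieces working together, here is a hedged, self-contained sketch of the reverse loop (PyTorch assumed). It condenses `step()` into a function and replaces the U-Net with a hypothetical "oracle" that computes the exact noise from a known $x_0$, iterating over the same timesteps that `set_inference_timesteps(50)` produces (980, 960, ..., 0). With that perfect prediction, the last step ($t = 0$, where no noise is added) should return $x_0$ itself. This only sanity-checks the arithmetic; in real sampling the noise prediction comes from a trained model and $x_0$ is unknown.

```python
import torch

def make_schedule(n=1000, beta_start=0.00085, beta_end=0.012):
    # Same scaled-linear schedule as DDPMSampler, in float64 for a tight check
    betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, n, dtype=torch.float64) ** 2
    return torch.cumprod(1.0 - betas, dim=0)

def ddpm_step(alphas_cumprod, t, prev_t, x_t, eps_pred, generator):
    # Condensed version of DDPMSampler.step()
    a_t = alphas_cumprod[t]
    a_prev = alphas_cumprod[prev_t] if prev_t >= 0 else torch.tensor(1.0, dtype=a_t.dtype)
    alpha_t = a_t / a_prev
    beta_t = 1 - alpha_t
    x0_pred = (x_t - (1 - a_t).sqrt() * eps_pred) / a_t.sqrt()
    mu = (a_prev.sqrt() * beta_t / (1 - a_t)) * x0_pred \
        + (alpha_t.sqrt() * (1 - a_prev) / (1 - a_t)) * x_t
    if t > 0:
        var = ((1 - a_prev) / (1 - a_t) * beta_t).clamp(min=1e-20)
        mu = mu + var.sqrt() * torch.randn(x_t.shape, generator=generator, dtype=x_t.dtype)
    return mu

def sample_with_oracle(x0, num_inference_steps=50, seed=0):
    alphas_cumprod = make_schedule()
    ratio = len(alphas_cumprod) // num_inference_steps
    g = torch.Generator().manual_seed(seed)
    x = torch.randn(x0.shape, generator=g, dtype=x0.dtype)  # start from pure noise
    for t in range(ratio * (num_inference_steps - 1), -1, -ratio):
        a_t = alphas_cumprod[t]
        # Hypothetical perfect noise predictor that knows x0 (stands in for the U-Net)
        eps = (x - a_t.sqrt() * x0) / (1 - a_t).sqrt()
        x = ddpm_step(alphas_cumprod, t, t - ratio, x, eps, g)
    return x

x0 = torch.linspace(-1, 1, 16, dtype=torch.float64).reshape(1, 1, 4, 4)
x_final = sample_with_oracle(x0)
print((x_final - x0).abs().max().item())  # only floating-point error should remain
```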



















