The Stable Diffusion model can also be applied to image-to-image generation by passing a text prompt and an initial image to condition the generation of new images.
Guided image synthesis enables everyday users to create and edit photo-realistic images with minimum effort. The key challenge is balancing faithfulness to the user input (e.g., hand-drawn colored strokes) and realism of the synthesized image. Existing GAN-based methods attempt to achieve such balance using either conditional GANs or GAN inversions, which are challenging and often require additional training data or loss functions for individual applications. To address these issues, we introduce a new image synthesis and editing method, Stochastic Differential Editing (SDEdit), based on a diffusion model generative prior, which synthesizes realistic images by iteratively denoising through a stochastic differential equation (SDE). Given an input image with user guide of any type, SDEdit first adds noise to the input, then subsequently denoises the resulting image through the SDE prior to increase its realism. SDEdit does not require task-specific training or inversions and can naturally achieve the balance between realism and faithfulness. SDEdit significantly outperforms state-of-the-art GAN-based methods by up to 98.09% on realism and 91.72% on overall satisfaction scores, according to a human perception study, on multiple tasks, including stroke-based image synthesis and editing as well as image compositing.
Tip
Make sure to check out the Stable Diffusion Tips section to learn how to explore the tradeoff between scheduler speed and quality, and how to reuse pipeline components efficiently!
Pipeline for text guided image-to-image generation using Stable Diffusion.
This model inherits from [DiffusionPipeline]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
The pipeline also inherits the following loading methods
[~loaders.TextualInversionLoaderMixin.load_textual_inversion] for loading textual inversion embeddings
[~loaders.LoraLoaderMixin.load_lora_weights] for loading LoRA weights
[~loaders.LoraLoaderMixin.save_lora_weights] for saving LoRA weights
[~loaders.FromSingleFileMixin.from_single_file] for loading .ckpt files
[~loaders.IPAdapterMixin.load_ip_adapter] for loading IP Adapters
PARAMETER
DESCRIPTION
vae
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
A UNet2DConditionModel to denoise the encoded image latents.
TYPE:[`UNet2DConditionModel`]
scheduler
A scheduler to be used in combination with unet to denoise the encoded image latents. Can be one of
[DDIMScheduler], [LMSDiscreteScheduler], or [PNDMScheduler].
TYPE:[`SchedulerMixin`]
safety_checker
Classification module that estimates whether generated images could be considered offensive or harmful.
Please refer to the model card for more details
about a model's potential harms.
TYPE:[`StableDiffusionSafetyChecker`]
feature_extractor
A CLIPImageProcessor to extract features from generated images; used as inputs to the safety_checker.
TYPE:[`~transformers.CLIPImageProcessor`]
Source code in mindone/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
classStableDiffusionImg2ImgPipeline(DiffusionPipeline,StableDiffusionMixin,IPAdapterMixin,TextualInversionLoaderMixin,LoraLoaderMixin,FromSingleFileMixin,):r""" Pipeline for text guided image-to-image generation using Stable Diffusion. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. text_encoder ([`~transformers.CLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. unet ([`UNet2DConditionModel`]): A `UNet2DConditionModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details about a model's potential harms. feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. """model_cpu_offload_seq="text_encoder->image_encoder->unet->vae"_optional_components=["safety_checker","feature_extractor","image_encoder"]_exclude_from_cpu_offload=["safety_checker"]_callback_tensor_inputs=["latents","prompt_embeds","negative_prompt_embeds"]def__init__(self,vae:AutoencoderKL,text_encoder:CLIPTextModel,tokenizer:CLIPTokenizer,unet:UNet2DConditionModel,scheduler:KarrasDiffusionSchedulers,safety_checker:StableDiffusionSafetyChecker,feature_extractor:CLIPImageProcessor,image_encoder:CLIPVisionModelWithProjection=None,requires_safety_checker:bool=True,):super().__init__()ifhasattr(scheduler.config,"steps_offset")andscheduler.config.steps_offset!=1:deprecation_message=(f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure ""to update the config accordingly as leaving `steps_offset` might led to incorrect results"" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"" file")deprecate("steps_offset!=1","1.0.0",deprecation_message,standard_warn=False)new_config=dict(scheduler.config)new_config["steps_offset"]=1scheduler._internal_dict=FrozenDict(new_config)ifhasattr(scheduler.config,"clip_sample")andscheduler.config.clip_sampleisTrue:deprecation_message=(f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."" `clip_sample` should be set to False in the configuration file. Please make sure to update the"" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file")deprecate("clip_sample not set","1.0.0",deprecation_message,standard_warn=False)new_config=dict(scheduler.config)new_config["clip_sample"]=Falsescheduler._internal_dict=FrozenDict(new_config)ifsafety_checkerisNoneandrequires_safety_checker:logger.warning(f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"" results in services or applications open to the public. Both the diffusers team and Hugging Face"" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"" it only for use-cases that involve analyzing network behavior or auditing its results. For more"" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .")ifsafety_checkerisnotNoneandfeature_extractorisNone:raiseValueError("Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.")is_unet_version_less_0_9_0=hasattr(unet.config,"_diffusers_version")andversion.parse(version.parse(unet.config._diffusers_version).base_version)<version.parse("0.9.0.dev0")is_unet_sample_size_less_64=hasattr(unet.config,"sample_size")andunet.config.sample_size<64ifis_unet_version_less_0_9_0andis_unet_sample_size_less_64:deprecation_message=("The configuration file of the unet has set the default `sample_size` to smaller than"" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"" in the config might lead to incorrect results in future versions. If you have downloaded this"" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"" the `unet/config.json` file")deprecate("sample_size<64","1.0.0",deprecation_message,standard_warn=False)new_config=dict(unet.config)new_config["sample_size"]=64unet._internal_dict=FrozenDict(new_config)self.register_modules(vae=vae,text_encoder=text_encoder,tokenizer=tokenizer,unet=unet,scheduler=scheduler,safety_checker=safety_checker,feature_extractor=feature_extractor,image_encoder=image_encoder,)self.vae_scale_factor=2**(len(self.vae.config.block_out_channels)-1)self.image_processor=VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)self.register_to_config(requires_safety_checker=requires_safety_checker)def_encode_prompt(self,prompt,num_images_per_prompt,do_classifier_free_guidance,negative_prompt=None,prompt_embeds:Optional[ms.Tensor]=None,negative_prompt_embeds:Optional[ms.Tensor]=None,lora_scale:Optional[float]=None,**kwargs,):deprecation_message="`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."# noqa: E501deprecate("_encode_prompt()","1.0.0",deprecation_message,standard_warn=False)prompt_embeds_tuple=self.encode_prompt(prompt=prompt,num_images_per_prompt=num_images_per_prompt,do_classifier_free_guidance=do_classifier_free_guidance,negative_prompt=negative_prompt,prompt_embeds=prompt_embeds,negative_prompt_embeds=negative_prompt_embeds,lora_scale=lora_scale,**kwargs,)# concatenate for backwards compprompt_embeds=ops.cat([prompt_embeds_tuple[1],prompt_embeds_tuple[0]])returnprompt_embedsdefencode_prompt(self,prompt,num_images_per_prompt,do_classifier_free_guidance,negative_prompt=None,prompt_embeds:Optional[ms.Tensor]=None,negative_prompt_embeds:Optional[ms.Tensor]=None,lora_scale:Optional[float]=None,clip_skip:Optional[int]=None,):r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`ms.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`ms.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. """# set lora scale so that monkey patched LoRA# function of text encoder can correctly access itiflora_scaleisnotNoneandisinstance(self,LoraLoaderMixin):self._lora_scale=lora_scale# dynamically adjust the LoRA scalescale_lora_layers(self.text_encoder,lora_scale)ifpromptisnotNoneandisinstance(prompt,str):batch_size=1elifpromptisnotNoneandisinstance(prompt,list):batch_size=len(prompt)else:batch_size=prompt_embeds.shape[0]ifprompt_embedsisNone:# textual inversion: process multi-vector tokens if necessaryifisinstance(self,TextualInversionLoaderMixin):prompt=self.maybe_convert_prompt(prompt,self.tokenizer)text_inputs=self.tokenizer(prompt,padding="max_length",max_length=self.tokenizer.model_max_length,truncation=True,return_tensors="np",)text_input_ids=text_inputs.input_idsuntruncated_ids=self.tokenizer(prompt,padding="longest",return_tensors="np").input_idsifuntruncated_ids.shape[-1]>=text_input_ids.shape[-1]andnotnp.array_equal(text_input_ids,untruncated_ids):removed_text=self.tokenizer.batch_decode(untruncated_ids[:,self.tokenizer.model_max_length-1:-1])logger.warning("The following part of your input was truncated because CLIP can only handle sequences up to"f" {self.tokenizer.model_max_length} tokens: {removed_text}")ifhasattr(self.text_encoder.config,"use_attention_mask")andself.text_encoder.config.use_attention_mask:attention_mask=ms.Tensor(text_inputs.attention_mask)else:attention_mask=Noneifclip_skipisNone:prompt_embeds=self.text_encoder(ms.Tensor(text_input_ids),attention_mask=attention_mask)prompt_embeds=prompt_embeds[0]else:prompt_embeds=self.text_encoder(ms.Tensor(text_input_ids),attention_mask=attention_mask,output_hidden_states=True)# Access the `hidden_states` first, that contains a tuple of# all the hidden states from the encoder layers. Then index into# the tuple to access the hidden states from the desired layer.prompt_embeds=prompt_embeds[-1][-(clip_skip+1)]# We also need to apply the final LayerNorm here to not mess with the# representations. The `last_hidden_states` that we typically use for# obtaining the final prompt representations passes through the LayerNorm# layer.prompt_embeds=self.text_encoder.text_model.final_layer_norm(prompt_embeds)ifself.text_encoderisnotNone:prompt_embeds_dtype=self.text_encoder.dtypeelifself.unetisnotNone:prompt_embeds_dtype=self.unet.dtypeelse:prompt_embeds_dtype=prompt_embeds.dtypeprompt_embeds=prompt_embeds.to(dtype=prompt_embeds_dtype)bs_embed,seq_len,_=prompt_embeds.shape# duplicate text embeddings for each generation per prompt, using mps friendly methodprompt_embeds=prompt_embeds.tile((1,num_images_per_prompt,1))prompt_embeds=prompt_embeds.view(bs_embed*num_images_per_prompt,seq_len,-1)# get unconditional embeddings for classifier free guidanceifdo_classifier_free_guidanceandnegative_prompt_embedsisNone:uncond_tokens:List[str]ifnegative_promptisNone:uncond_tokens=[""]*batch_sizeelifpromptisnotNoneandtype(prompt)isnottype(negative_prompt):raiseTypeError(f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="f" {type(prompt)}.")elifisinstance(negative_prompt,str):uncond_tokens=[negative_prompt]elifbatch_size!=len(negative_prompt):raiseValueError(f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"" the batch size of `prompt`.")else:uncond_tokens=negative_prompt# textual inversion: process multi-vector tokens if necessaryifisinstance(self,TextualInversionLoaderMixin):uncond_tokens=self.maybe_convert_prompt(uncond_tokens,self.tokenizer)max_length=prompt_embeds.shape[1]uncond_input=self.tokenizer(uncond_tokens,padding="max_length",max_length=max_length,truncation=True,return_tensors="np",)ifhasattr(self.text_encoder.config,"use_attention_mask")andself.text_encoder.config.use_attention_mask:attention_mask=ms.Tensor(uncond_input.attention_mask)else:attention_mask=Nonenegative_prompt_embeds=self.text_encoder(ms.Tensor(uncond_input.input_ids),attention_mask=attention_mask,)negative_prompt_embeds=negative_prompt_embeds[0]ifdo_classifier_free_guidance:# duplicate unconditional embeddings for each generation per prompt, using mps friendly methodseq_len=negative_prompt_embeds.shape[1]negative_prompt_embeds=negative_prompt_embeds.to(dtype=prompt_embeds_dtype)negative_prompt_embeds=negative_prompt_embeds.tile((1,num_images_per_prompt,1))negative_prompt_embeds=negative_prompt_embeds.view(batch_size*num_images_per_prompt,seq_len,-1)ifself.text_encoderisnotNone:ifisinstance(self,LoraLoaderMixin):# Retrieve the original scale by scaling back the LoRA layersunscale_lora_layers(self.text_encoder,lora_scale)returnprompt_embeds,negative_prompt_embedsdefencode_image(self,image,num_images_per_prompt,output_hidden_states=None):dtype=next(self.image_encoder.get_parameters()).dtypeifnotisinstance(image,ms.Tensor):image=self.feature_extractor(image,return_tensors="np").pixel_valuesimage=ms.Tensor(image)image=image.to(dtype=dtype)ifoutput_hidden_states:image_enc_hidden_states=self.image_encoder(image,output_hidden_states=True)[2][-2]image_enc_hidden_states=image_enc_hidden_states.repeat_interleave(num_images_per_prompt,dim=0)uncond_image_enc_hidden_states=self.image_encoder(ops.zeros_like(image),output_hidden_states=True)[2][-2]uncond_image_enc_hidden_states=uncond_image_enc_hidden_states.repeat_interleave(num_images_per_prompt,dim=0)returnimage_enc_hidden_states,uncond_image_enc_hidden_stateselse:image_embeds=self.image_encoder(image)[0]image_embeds=image_embeds.repeat_interleave(num_images_per_prompt,dim=0)uncond_image_embeds=ops.zeros_like(image_embeds)returnimage_embeds,uncond_image_embedsdefprepare_ip_adapter_image_embeds(self,ip_adapter_image,ip_adapter_image_embeds,num_images_per_prompt,do_classifier_free_guidance):ifip_adapter_image_embedsisNone:ifnotisinstance(ip_adapter_image,list):ip_adapter_image=[ip_adapter_image]iflen(ip_adapter_image)!=len(self.unet.encoder_hid_proj.image_projection_layers):raiseValueError(f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."# noqa: E501)image_embeds=[]forsingle_ip_adapter_image,image_proj_layerinzip(ip_adapter_image,self.unet.encoder_hid_proj.image_projection_layers):output_hidden_state=notisinstance(image_proj_layer,ImageProjection)single_image_embeds,single_negative_image_embeds=self.encode_image(single_ip_adapter_image,1,output_hidden_state)single_image_embeds=ops.stack([single_image_embeds]*num_images_per_prompt,axis=0)single_negative_image_embeds=ops.stack([single_negative_image_embeds]*num_images_per_prompt,axis=0)ifdo_classifier_free_guidance:single_image_embeds=ops.cat([single_negative_image_embeds,single_image_embeds])image_embeds.append(single_image_embeds)else:repeat_dims=[1]image_embeds=[]forsingle_image_embedsinip_adapter_image_embeds:ifdo_classifier_free_guidance:single_negative_image_embeds,single_image_embeds=single_image_embeds.chunk(2)single_image_embeds=single_image_embeds.tile((num_images_per_prompt,*(repeat_dims*len(single_image_embeds.shape[1:]))))single_negative_image_embeds=single_negative_image_embeds.tile((num_images_per_prompt,*(repeat_dims*len(single_negative_image_embeds.shape[1:]))))single_image_embeds=ops.cat([single_negative_image_embeds,single_image_embeds])else:single_image_embeds=single_image_embeds.tile((num_images_per_prompt,*(repeat_dims*len(single_image_embeds.shape[1:]))))image_embeds.append(single_image_embeds)returnimage_embedsdefrun_safety_checker(self,image,dtype):ifself.safety_checkerisNone:has_nsfw_concept=Noneelse:ifops.is_tensor(image):feature_extractor_input=self.image_processor.postprocess(image,output_type="pil")else:feature_extractor_input=self.image_processor.numpy_to_pil(image)safety_checker_input=self.feature_extractor(feature_extractor_input,return_tensors="np")image,has_nsfw_concept=self.safety_checker(images=image,clip_input=ms.Tensor(safety_checker_input.pixel_values).to(dtype))# Warning for safety checker operations here as it couldn't been done in construct()ifops.any(has_nsfw_concept):logger.warning("Potential NSFW content was detected in one or more images. A black image will be returned instead."" Try again with a different prompt and/or seed.")returnimage,has_nsfw_conceptdefdecode_latents(self,latents):deprecation_message="The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"deprecate("decode_latents","1.0.0",deprecation_message,standard_warn=False)latents=1/self.vae.config.scaling_factor*latentsimage=self.vae.decode(latents,return_dict=False)[0]image=(image/2+0.5).clamp(0,1)# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16image=image.permute(0,2,3,1).float().numpy()returnimagedefprepare_extra_step_kwargs(self,generator,eta):# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502# and should be between [0, 1]accepts_eta="eta"inset(inspect.signature(self.scheduler.step).parameters.keys())extra_step_kwargs={}ifaccepts_eta:extra_step_kwargs["eta"]=eta# check if the scheduler accepts generatoraccepts_generator="generator"inset(inspect.signature(self.scheduler.step).parameters.keys())ifaccepts_generator:extra_step_kwargs["generator"]=generatorreturnextra_step_kwargsdefcheck_inputs(self,prompt,strength,callback_steps,negative_prompt=None,prompt_embeds=None,negative_prompt_embeds=None,ip_adapter_image=None,ip_adapter_image_embeds=None,callback_on_step_end_tensor_inputs=None,):ifstrength<0orstrength>1:raiseValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")ifcallback_stepsisnotNoneand(notisinstance(callback_steps,int)orcallback_steps<=0):raiseValueError(f"`callback_steps` has to be a positive integer but is {callback_steps} of type"f" {type(callback_steps)}.")ifcallback_on_step_end_tensor_inputsisnotNoneandnotall(kinself._callback_tensor_inputsforkincallback_on_step_end_tensor_inputs):raiseValueError(f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[kforkincallback_on_step_end_tensor_inputsifknotinself._callback_tensor_inputs]}"# noqa: E501)ifpromptisnotNoneandprompt_embedsisnotNone:raiseValueError(f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"" only forward one of the two.")elifpromptisNoneandprompt_embedsisNone:raiseValueError("Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.")elifpromptisnotNoneand(notisinstance(prompt,str)andnotisinstance(prompt,list)):raiseValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")ifnegative_promptisnotNoneandnegative_prompt_embedsisnotNone:raiseValueError(f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"f" {negative_prompt_embeds}. Please make sure to only forward one of the two.")ifprompt_embedsisnotNoneandnegative_prompt_embedsisnotNone:ifprompt_embeds.shape!=negative_prompt_embeds.shape:raiseValueError("`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"f" {negative_prompt_embeds.shape}.")ifip_adapter_imageisnotNoneandip_adapter_image_embedsisnotNone:raiseValueError("Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined.")ifip_adapter_image_embedsisnotNone:ifnotisinstance(ip_adapter_image_embeds,list):raiseValueError(f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}")elifip_adapter_image_embeds[0].ndimnotin[3,4]:raiseValueError(f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D")defget_timesteps(self,num_inference_steps,strength):# get the original timestep using init_timestepinit_timestep=min(int(num_inference_steps*strength),num_inference_steps)t_start=max(num_inference_steps-init_timestep,0)timesteps=self.scheduler.timesteps[t_start*self.scheduler.order:]ifhasattr(self.scheduler,"set_begin_index"):self.scheduler.set_begin_index(t_start*self.scheduler.order)returntimesteps,num_inference_steps-t_startdefprepare_latents(self,image,timestep,batch_size,num_images_per_prompt,dtype,generator=None):ifnotisinstance(image,(ms.Tensor,PIL.Image.Image,list)):raiseValueError(f"`image` has to be of type `mindspore.Tensor`, `PIL.Image.Image` or list but is {type(image)}")image=image.to(dtype=dtype)batch_size=batch_size*num_images_per_promptifimage.shape[1]==4:init_latents=imageelse:ifisinstance(generator,list)andlen(generator)!=batch_size:raiseValueError(f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"f" size of {batch_size}. Make sure the batch size matches the length of the generators.")elifisinstance(generator,list):init_latents=[retrieve_latents(self.vae,self.vae.encode(image[i:i+1])[0],generator)foriinrange(batch_size)]init_latents=ops.cat(init_latents,axis=0)else:init_latents=retrieve_latents(self.vae,self.vae.encode(image)[0],generator)init_latents=self.vae.config.scaling_factor*init_latentsifbatch_size>init_latents.shape[0]andbatch_size%init_latents.shape[0]==0:# expand init_latents for batch_sizedeprecation_message=(f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"" images (`image`). Initial images are now duplicating to match the number of text prompts. Note"" that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"" your script to pass as many initial images as text prompts to suppress this warning.")deprecate("len(prompt) != len(image)","1.0.0",deprecation_message,standard_warn=False)additional_image_per_prompt=batch_size//init_latents.shape[0]init_latents=ops.cat([init_latents]*additional_image_per_prompt,axis=0)elifbatch_size>init_latents.shape[0]andbatch_size%init_latents.shape[0]!=0:raiseValueError(f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts.")else:init_latents=ops.cat([init_latents],axis=0)shape=init_latents.shapenoise=randn_tensor(shape,generator=generator,dtype=dtype)# get latentsinit_latents=self.scheduler.add_noise(init_latents,noise,timestep)latents=init_latentsreturnlatents# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embeddingdefget_guidance_scale_embedding(self,w:ms.Tensor,embedding_dim:int=512,dtype:ms.Type=ms.float32)->ms.Tensor:""" See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 Args: w (`ms.Tensor`): Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. embedding_dim (`int`, *optional*, defaults to 512): Dimension of the embeddings to generate. dtype (`ms.dtype`, *optional*, defaults to `ms.float32`): Data type of the generated embeddings. Returns: `ms.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. """assertlen(w.shape)==1w=w*1000.0half_dim=embedding_dim//2emb=ops.log(ms.tensor(10000.0))/(half_dim-1)emb=ops.exp(ops.arange(half_dim,dtype=dtype)*-emb)emb=w.to(dtype)[:,None]*emb[None,:]emb=ops.cat([ops.sin(emb),ops.cos(emb)],axis=1)ifembedding_dim%2==1:# zero pademb=ops.pad(emb,(0,1))assertemb.shape==(w.shape[0],embedding_dim)returnemb@propertydefguidance_scale(self):returnself._guidance_scale@propertydefclip_skip(self):returnself._clip_skip# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`# corresponds to doing no classifier free guidance.@propertydefdo_classifier_free_guidance(self):returnself._guidance_scale>1andself.unet.config.time_cond_proj_dimisNone@propertydefcross_attention_kwargs(self):returnself._cross_attention_kwargs@propertydefnum_timesteps(self):returnself._num_timesteps@propertydefinterrupt(self):returnself._interruptdef__call__(self,prompt:Union[str,List[str]]=None,image:PipelineImageInput=None,strength:float=0.8,num_inference_steps:int=50,timesteps:List[int]=None,sigmas:List[float]=None,guidance_scale:float=7.5,negative_prompt:Optional[Union[str,List[str]]]=None,num_images_per_prompt:Optional[int]=1,eta:float=0.0,generator:Optional[Union[np.random.Generator,List[np.random.Generator]]]=None,prompt_embeds:Optional[ms.Tensor]=None,negative_prompt_embeds:Optional[ms.Tensor]=None,ip_adapter_image:Optional[PipelineImageInput]=None,ip_adapter_image_embeds:Optional[List[ms.Tensor]]=None,output_type:Optional[str]="pil",return_dict:bool=False,cross_attention_kwargs:Optional[Dict[str,Any]]=None,clip_skip:Optional[int]=None,callback_on_step_end:Optional[Union[Callable[[int,int,Dict],None],PipelineCallback,MultiPipelineCallbacks]]=None,callback_on_step_end_tensor_inputs:List[str]=["latents"],**kwargs,):r""" The call function to the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. image (`ms.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[ms.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both numpy array and mindspore tensor, the expected value range is between `[0, 1]` If it's a tensor or a list or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image latents as `image`, but if passing latents directly it is not encoded again. strength (`float`, *optional*, defaults to 0.8): Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a starting point and more noise is added the higher the `strength`. The number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 essentially ignores `image`. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. This parameter is modulated by `strength`. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. sigmas (`List[float]`, *optional*): Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in image generation. If not defined, you need to pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. generator (`np.random.Generator` or `List[np.random.Generator]`, *optional*): A [`np.random.Generator`](https://numpy.org/doc/stable/reference/random/generator.html) to make generation deterministic. prompt_embeds (`ms.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. negative_prompt_embeds (`ms.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. ip_adapter_image_embeds (`List[ms.Tensor]`, *optional*): Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `False`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of each denoising step during the inference. with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images and the second element is a list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """callback=kwargs.pop("callback",None)callback_steps=kwargs.pop("callback_steps",None)ifcallbackisnotNone:deprecate("callback","1.0.0","Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",)ifcallback_stepsisnotNone:deprecate("callback_steps","1.0.0","Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",)ifisinstance(callback_on_step_end,(PipelineCallback,MultiPipelineCallbacks)):callback_on_step_end_tensor_inputs=callback_on_step_end.tensor_inputs# 1. Check inputs. Raise error if not correctself.check_inputs(prompt,strength,callback_steps,negative_prompt,prompt_embeds,negative_prompt_embeds,ip_adapter_image,ip_adapter_image_embeds,callback_on_step_end_tensor_inputs,)self._guidance_scale=guidance_scaleself._clip_skip=clip_skipself._cross_attention_kwargs=cross_attention_kwargsself._interrupt=False# 2. Define call parametersifpromptisnotNoneandisinstance(prompt,str):batch_size=1elifpromptisnotNoneandisinstance(prompt,list):batch_size=len(prompt)else:batch_size=prompt_embeds.shape[0]# 3. Encode input prompttext_encoder_lora_scale=(self.cross_attention_kwargs.get("scale",None)ifself.cross_attention_kwargsisnotNoneelseNone)prompt_embeds,negative_prompt_embeds=self.encode_prompt(prompt,num_images_per_prompt,self.do_classifier_free_guidance,negative_prompt,prompt_embeds=prompt_embeds,negative_prompt_embeds=negative_prompt_embeds,lora_scale=text_encoder_lora_scale,clip_skip=self.clip_skip,)# For classifier free guidance, we need to do two forward passes.# Here we concatenate the unconditional and text embeddings into a single batch# to avoid doing two forward passesifself.do_classifier_free_guidance:prompt_embeds=ops.cat([negative_prompt_embeds,prompt_embeds])ifip_adapter_imageisnotNoneorip_adapter_image_embedsisnotNone:image_embeds=self.prepare_ip_adapter_image_embeds(ip_adapter_image,ip_adapter_image_embeds,batch_size*num_images_per_prompt,self.do_classifier_free_guidance,)# 4. Preprocess imageimage=self.image_processor.preprocess(image)# 5. set timestepstimesteps,num_inference_steps=retrieve_timesteps(self.scheduler,num_inference_steps,timesteps,sigmas)timesteps,num_inference_steps=self.get_timesteps(num_inference_steps,strength)latent_timestep=timesteps[:1].tile((batch_size*num_images_per_prompt,))# 6. Prepare latent variableslatents=self.prepare_latents(image,latent_timestep,batch_size,num_images_per_prompt,prompt_embeds.dtype,generator,)# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipelineextra_step_kwargs=self.prepare_extra_step_kwargs(generator,eta)# 7.1 Add image embeds for IP-Adapteradded_cond_kwargs=({"image_embeds":image_embeds}if(ip_adapter_imageisnotNoneorip_adapter_image_embedsisnotNone)elseNone)# 7.2 Optionally get Guidance Scale Embeddingtimestep_cond=Noneifself.unet.config.time_cond_proj_dimisnotNone:guidance_scale_tensor=ms.Tensor(self.guidance_scale-1).tile((batch_size*num_images_per_prompt))timestep_cond=self.get_guidance_scale_embedding(guidance_scale_tensor,embedding_dim=self.unet.config.time_cond_proj_dim).to(dtype=latents.dtype)# we're popping the `scale` instead of getting it because otherwise `scale` will be propagated# to the unet and will raise RuntimeError.lora_scale=self.cross_attention_kwargs.pop("scale",None)ifself.cross_attention_kwargsisnotNoneelseNoneiflora_scaleisnotNone:# weight the lora layers by setting `lora_scale` for each PEFT layerscale_lora_layers(self.unet,lora_scale)# 8. Denoising loopnum_warmup_steps=len(timesteps)-num_inference_steps*self.scheduler.orderself._num_timesteps=len(timesteps)withself.progress_bar(total=num_inference_steps)asprogress_bar:fori,tinenumerate(timesteps):ifself.interrupt:continue# expand the latents if we are doing classifier free guidancelatent_model_input=ops.cat([latents]*2)ifself.do_classifier_free_guidanceelselatents# TODO: method of scheduler should not change the dtype of input.# Remove the casting after cuiyushi confirm that.tmp_dtype=latent_model_input.dtypelatent_model_input=self.scheduler.scale_model_input(latent_model_input,t)latent_model_input=latent_model_input.to(tmp_dtype)# predict the noise residualnoise_pred=self.unet(latent_model_input,t,encoder_hidden_states=prompt_embeds,timestep_cond=timestep_cond,cross_attention_kwargs=self.cross_attention_kwargs,added_cond_kwargs=ms.mutable(added_cond_kwargs)ifadded_cond_kwargselseadded_cond_kwargs,return_dict=False,)[0]# perform guidanceifself.do_classifier_free_guidance:noise_pred_uncond,noise_pred_text=noise_pred.chunk(2)noise_pred=noise_pred_uncond+self.guidance_scale*(noise_pred_text-noise_pred_uncond)# compute the previous noisy sample x_t -> x_t-1# TODO: method of scheduler should not change the dtype of input.# Remove the casting after cuiyushi confirm that.tmp_dtype=latents.dtypelatents=self.scheduler.step(noise_pred,t,latents,**extra_step_kwargs,return_dict=False)[0]latents=latents.to(tmp_dtype)ifcallback_on_step_endisnotNone:callback_kwargs={}forkincallback_on_step_end_tensor_inputs:callback_kwargs[k]=locals()[k]callback_outputs=callback_on_step_end(self,i,t,callback_kwargs)latents=callback_outputs.pop("latents",latents)prompt_embeds=callback_outputs.pop("prompt_embeds",prompt_embeds)negative_prompt_embeds=callback_outputs.pop("negative_prompt_embeds",negative_prompt_embeds)# call the callback, if providedifi==len(timesteps)-1or((i+1)>num_warmup_stepsand(i+1)%self.scheduler.order==0):progress_bar.update()ifcallbackisnotNoneandi%callback_steps==0:step_idx=i//getattr(self.scheduler,"order",1)callback(step_idx,t,latents)iflora_scaleisnotNone:# remove `lora_scale` from each PEFT layerunscale_lora_layers(self.unet,lora_scale)ifnotoutput_type=="latent":latents=(latents/self.vae.config.scaling_factor).to(self.vae.dtype)image=self.vae.decode(latents,return_dict=False)[0]image,has_nsfw_concept=self.run_safety_checker(image,prompt_embeds.dtype)else:image=latentshas_nsfw_concept=Noneifhas_nsfw_conceptisNone:do_denormalize=[True]*image.shape[0]else:do_denormalize=[nothas_nsfwforhas_nsfwinhas_nsfw_concept]image=self.image_processor.postprocess(image,output_type=output_type,do_denormalize=do_denormalize)ifnotreturn_dict:return(image,has_nsfw_concept)returnStableDiffusionPipelineOutput(images=image,nsfw_content_detected=has_nsfw_concept)
The prompt or prompts to guide image generation. If not defined, you need to pass prompt_embeds.
TYPE:`str` or `List[str]`, *optional*DEFAULT:None
image
Image, numpy array or tensor representing an image batch to be used as the starting point. For both
numpy array and mindspore tensor, the expected value range is between [0, 1] If it's a tensor or a list
or tensors, the expected shape should be (B, C, H, W) or (C, H, W). If it is a numpy array or a
list of arrays, the expected shape should be (B, H, W, C) or (H, W, C) It can also accept image
latents as image, but if passing latents directly it is not encoded again.
TYPE:`ms.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[ms.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`DEFAULT:None
strength
Indicates extent to transform the reference image. Must be between 0 and 1. image is used as a
starting point and more noise is added the higher the strength. The number of denoising steps depends
on the amount of noise initially added. When strength is 1, added noise is maximum and the denoising
process runs for the full number of iterations specified in num_inference_steps. A value of 1
essentially ignores image.
TYPE:`float`, *optional*, defaults to 0.8DEFAULT:0.8
num_inference_steps
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. This parameter is modulated by strength.
TYPE:`int`, *optional*, defaults to 50DEFAULT:50
timesteps
Custom timesteps to use for the denoising process with schedulers which support a timesteps argument
in their set_timesteps method. If not defined, the default behavior when num_inference_steps is
passed will be used. Must be in descending order.
TYPE:`List[int]`, *optional*DEFAULT:None
sigmas
Custom sigmas to use for the denoising process with schedulers which support a sigmas argument in
their set_timesteps method. If not defined, the default behavior when num_inference_steps is passed
will be used.
TYPE:`List[float]`, *optional*DEFAULT:None
guidance_scale
A higher guidance scale value encourages the model to generate images closely linked to the text
prompt at the expense of lower image quality. Guidance scale is enabled when guidance_scale > 1.
TYPE:`float`, *optional*, defaults to 7.5DEFAULT:7.5
negative_prompt
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
pass negative_prompt_embeds instead. Ignored when not using guidance (guidance_scale < 1).
TYPE:`str` or `List[str]`, *optional*DEFAULT:None
num_images_per_prompt
The number of images to generate per prompt.
TYPE:`int`, *optional*, defaults to 1DEFAULT:1
eta
Corresponds to parameter eta (η) from the DDIM paper. Only applies
to the [~schedulers.DDIMScheduler], and is ignored in other schedulers.
TYPE:`float`, *optional*, defaults to 0.0DEFAULT:0.0
TYPE:`np.random.Generator` or `List[np.random.Generator]`, *optional*DEFAULT:None
prompt_embeds
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the prompt input argument.
TYPE:`ms.Tensor`, *optional*DEFAULT:None
negative_prompt_embeds
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, negative_prompt_embeds are generated from the negative_prompt input argument.
TYPE:`ms.Tensor`, *optional*DEFAULT:None
ip_adapter_image
(PipelineImageInput, optional): Optional image input to work with IP Adapters.
TYPE:Optional[PipelineImageInput]DEFAULT:None
ip_adapter_image_embeds
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
Each element should be a tensor of shape (batch_size, num_images, emb_dim). It should contain the negative image embedding
if do_classifier_free_guidance is set to True.
If not provided, embeddings are computed from the ip_adapter_image input argument.
TYPE:`List[ms.Tensor]`, *optional*DEFAULT:None
output_type
The output format of the generated image. Choose between PIL.Image or np.array.
TYPE:`str`, *optional*, defaults to `"pil"`DEFAULT:'pil'
return_dict
Whether or not to return a [~pipelines.stable_diffusion.StableDiffusionPipelineOutput] instead of a
plain tuple.
TYPE:`bool`, *optional*, defaults to `False`DEFAULT:False
cross_attention_kwargs
A kwargs dictionary that if specified is passed along to the [AttentionProcessor] as defined in
self.processor.
TYPE:`dict`, *optional*DEFAULT:None
clip_skip
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
TYPE:`int`, *optional*DEFAULT:None
callback_on_step_end
A function or a subclass of PipelineCallback or MultiPipelineCallbacks that is called at the end of
each denoising step during the inference. with the following arguments: callback_on_step_end(self:
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict). callback_kwargs will include a
list of all tensors as specified by callback_on_step_end_tensor_inputs.
The list of tensor inputs for the callback_on_step_end function. The tensors specified in the list
will be passed as callback_kwargs argument. You will only be able to include variables listed in the
._callback_tensor_inputs attribute of your pipeline class.
TYPE:`List`, *optional*DEFAULT:['latents']
RETURNS
DESCRIPTION
[~pipelines.stable_diffusion.StableDiffusionPipelineOutput] or tuple:
If return_dict is True, [~pipelines.stable_diffusion.StableDiffusionPipelineOutput] is returned,
otherwise a tuple is returned where the first element is a list with the generated images and the
second element is a list of bools indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
Source code in mindone/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
def__call__(self,prompt:Union[str,List[str]]=None,image:PipelineImageInput=None,strength:float=0.8,num_inference_steps:int=50,timesteps:List[int]=None,sigmas:List[float]=None,guidance_scale:float=7.5,negative_prompt:Optional[Union[str,List[str]]]=None,num_images_per_prompt:Optional[int]=1,eta:float=0.0,generator:Optional[Union[np.random.Generator,List[np.random.Generator]]]=None,prompt_embeds:Optional[ms.Tensor]=None,negative_prompt_embeds:Optional[ms.Tensor]=None,ip_adapter_image:Optional[PipelineImageInput]=None,ip_adapter_image_embeds:Optional[List[ms.Tensor]]=None,output_type:Optional[str]="pil",return_dict:bool=False,cross_attention_kwargs:Optional[Dict[str,Any]]=None,clip_skip:Optional[int]=None,callback_on_step_end:Optional[Union[Callable[[int,int,Dict],None],PipelineCallback,MultiPipelineCallbacks]]=None,callback_on_step_end_tensor_inputs:List[str]=["latents"],**kwargs,):r""" The call function to the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. image (`ms.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[ms.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both numpy array and mindspore tensor, the expected value range is between `[0, 1]` If it's a tensor or a list or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image latents as `image`, but if passing latents directly it is not encoded again. strength (`float`, *optional*, defaults to 0.8): Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a starting point and more noise is added the higher the `strength`. The number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 essentially ignores `image`. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. This parameter is modulated by `strength`. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. sigmas (`List[float]`, *optional*): Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in image generation. If not defined, you need to pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. generator (`np.random.Generator` or `List[np.random.Generator]`, *optional*): A [`np.random.Generator`](https://numpy.org/doc/stable/reference/random/generator.html) to make generation deterministic. prompt_embeds (`ms.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. negative_prompt_embeds (`ms.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. ip_adapter_image_embeds (`List[ms.Tensor]`, *optional*): Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `False`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of each denoising step during the inference. with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images and the second element is a list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """callback=kwargs.pop("callback",None)callback_steps=kwargs.pop("callback_steps",None)ifcallbackisnotNone:deprecate("callback","1.0.0","Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",)ifcallback_stepsisnotNone:deprecate("callback_steps","1.0.0","Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",)ifisinstance(callback_on_step_end,(PipelineCallback,MultiPipelineCallbacks)):callback_on_step_end_tensor_inputs=callback_on_step_end.tensor_inputs# 1. Check inputs. Raise error if not correctself.check_inputs(prompt,strength,callback_steps,negative_prompt,prompt_embeds,negative_prompt_embeds,ip_adapter_image,ip_adapter_image_embeds,callback_on_step_end_tensor_inputs,)self._guidance_scale=guidance_scaleself._clip_skip=clip_skipself._cross_attention_kwargs=cross_attention_kwargsself._interrupt=False# 2. Define call parametersifpromptisnotNoneandisinstance(prompt,str):batch_size=1elifpromptisnotNoneandisinstance(prompt,list):batch_size=len(prompt)else:batch_size=prompt_embeds.shape[0]# 3. Encode input prompttext_encoder_lora_scale=(self.cross_attention_kwargs.get("scale",None)ifself.cross_attention_kwargsisnotNoneelseNone)prompt_embeds,negative_prompt_embeds=self.encode_prompt(prompt,num_images_per_prompt,self.do_classifier_free_guidance,negative_prompt,prompt_embeds=prompt_embeds,negative_prompt_embeds=negative_prompt_embeds,lora_scale=text_encoder_lora_scale,clip_skip=self.clip_skip,)# For classifier free guidance, we need to do two forward passes.# Here we concatenate the unconditional and text embeddings into a single batch# to avoid doing two forward passesifself.do_classifier_free_guidance:prompt_embeds=ops.cat([negative_prompt_embeds,prompt_embeds])ifip_adapter_imageisnotNoneorip_adapter_image_embedsisnotNone:image_embeds=self.prepare_ip_adapter_image_embeds(ip_adapter_image,ip_adapter_image_embeds,batch_size*num_images_per_prompt,self.do_classifier_free_guidance,)# 4. Preprocess imageimage=self.image_processor.preprocess(image)# 5. set timestepstimesteps,num_inference_steps=retrieve_timesteps(self.scheduler,num_inference_steps,timesteps,sigmas)timesteps,num_inference_steps=self.get_timesteps(num_inference_steps,strength)latent_timestep=timesteps[:1].tile((batch_size*num_images_per_prompt,))# 6. Prepare latent variableslatents=self.prepare_latents(image,latent_timestep,batch_size,num_images_per_prompt,prompt_embeds.dtype,generator,)# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipelineextra_step_kwargs=self.prepare_extra_step_kwargs(generator,eta)# 7.1 Add image embeds for IP-Adapteradded_cond_kwargs=({"image_embeds":image_embeds}if(ip_adapter_imageisnotNoneorip_adapter_image_embedsisnotNone)elseNone)# 7.2 Optionally get Guidance Scale Embeddingtimestep_cond=Noneifself.unet.config.time_cond_proj_dimisnotNone:guidance_scale_tensor=ms.Tensor(self.guidance_scale-1).tile((batch_size*num_images_per_prompt))timestep_cond=self.get_guidance_scale_embedding(guidance_scale_tensor,embedding_dim=self.unet.config.time_cond_proj_dim).to(dtype=latents.dtype)# we're popping the `scale` instead of getting it because otherwise `scale` will be propagated# to the unet and will raise RuntimeError.lora_scale=self.cross_attention_kwargs.pop("scale",None)ifself.cross_attention_kwargsisnotNoneelseNoneiflora_scaleisnotNone:# weight the lora layers by setting `lora_scale` for each PEFT layerscale_lora_layers(self.unet,lora_scale)# 8. Denoising loopnum_warmup_steps=len(timesteps)-num_inference_steps*self.scheduler.orderself._num_timesteps=len(timesteps)withself.progress_bar(total=num_inference_steps)asprogress_bar:fori,tinenumerate(timesteps):ifself.interrupt:continue# expand the latents if we are doing classifier free guidancelatent_model_input=ops.cat([latents]*2)ifself.do_classifier_free_guidanceelselatents# TODO: method of scheduler should not change the dtype of input.# Remove the casting after cuiyushi confirm that.tmp_dtype=latent_model_input.dtypelatent_model_input=self.scheduler.scale_model_input(latent_model_input,t)latent_model_input=latent_model_input.to(tmp_dtype)# predict the noise residualnoise_pred=self.unet(latent_model_input,t,encoder_hidden_states=prompt_embeds,timestep_cond=timestep_cond,cross_attention_kwargs=self.cross_attention_kwargs,added_cond_kwargs=ms.mutable(added_cond_kwargs)ifadded_cond_kwargselseadded_cond_kwargs,return_dict=False,)[0]# perform guidanceifself.do_classifier_free_guidance:noise_pred_uncond,noise_pred_text=noise_pred.chunk(2)noise_pred=noise_pred_uncond+self.guidance_scale*(noise_pred_text-noise_pred_uncond)# compute the previous noisy sample x_t -> x_t-1# TODO: method of scheduler should not change the dtype of input.# Remove the casting after cuiyushi confirm that.tmp_dtype=latents.dtypelatents=self.scheduler.step(noise_pred,t,latents,**extra_step_kwargs,return_dict=False)[0]latents=latents.to(tmp_dtype)ifcallback_on_step_endisnotNone:callback_kwargs={}forkincallback_on_step_end_tensor_inputs:callback_kwargs[k]=locals()[k]callback_outputs=callback_on_step_end(self,i,t,callback_kwargs)latents=callback_outputs.pop("latents",latents)prompt_embeds=callback_outputs.pop("prompt_embeds",prompt_embeds)negative_prompt_embeds=callback_outputs.pop("negative_prompt_embeds",negative_prompt_embeds)# call the callback, if providedifi==len(timesteps)-1or((i+1)>num_warmup_stepsand(i+1)%self.scheduler.order==0):progress_bar.update()ifcallbackisnotNoneandi%callback_steps==0:step_idx=i//getattr(self.scheduler,"order",1)callback(step_idx,t,latents)iflora_scaleisnotNone:# remove `lora_scale` from each PEFT layerunscale_lora_layers(self.unet,lora_scale)ifnotoutput_type=="latent":latents=(latents/self.vae.config.scaling_factor).to(self.vae.dtype)image=self.vae.decode(latents,return_dict=False)[0]image,has_nsfw_concept=self.run_safety_checker(image,prompt_embeds.dtype)else:image=latentshas_nsfw_concept=Noneifhas_nsfw_conceptisNone:do_denormalize=[True]*image.shape[0]else:do_denormalize=[nothas_nsfwforhas_nsfwinhas_nsfw_concept]image=self.image_processor.postprocess(image,output_type=output_type,do_denormalize=do_denormalize)ifnotreturn_dict:return(image,has_nsfw_concept)returnStableDiffusionPipelineOutput(images=image,nsfw_content_detected=has_nsfw_concept)
Encodes the prompt into text encoder hidden states.
PARAMETER
DESCRIPTION
prompt
prompt to be encoded
TYPE:`str` or `List[str]`, *optional*
num_images_per_prompt
number of images that should be generated per prompt
TYPE:`int`
do_classifier_free_guidance
whether to use classifier free guidance or not
TYPE:`bool`
negative_prompt
The prompt or prompts not to guide the image generation. If not defined, one has to pass
negative_prompt_embeds instead. Ignored when not using guidance (i.e., ignored if guidance_scale is
less than 1).
TYPE:`str` or `List[str]`, *optional*DEFAULT:None
prompt_embeds
Pre-generated text embeddings. Can be used to easily tweak text inputs, e.g. prompt weighting. If not
provided, text embeddings will be generated from prompt input argument.
TYPE:`ms.Tensor`, *optional*DEFAULT:None
negative_prompt_embeds
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, e.g. prompt
weighting. If not provided, negative_prompt_embeds will be generated from negative_prompt input
argument.
TYPE:`ms.Tensor`, *optional*DEFAULT:None
lora_scale
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
TYPE:`float`, *optional*DEFAULT:None
clip_skip
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
TYPE:`int`, *optional*DEFAULT:None
Source code in mindone/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
defencode_prompt(self,prompt,num_images_per_prompt,do_classifier_free_guidance,negative_prompt=None,prompt_embeds:Optional[ms.Tensor]=None,negative_prompt_embeds:Optional[ms.Tensor]=None,lora_scale:Optional[float]=None,clip_skip:Optional[int]=None,):r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`ms.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`ms.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. """# set lora scale so that monkey patched LoRA# function of text encoder can correctly access itiflora_scaleisnotNoneandisinstance(self,LoraLoaderMixin):self._lora_scale=lora_scale# dynamically adjust the LoRA scalescale_lora_layers(self.text_encoder,lora_scale)ifpromptisnotNoneandisinstance(prompt,str):batch_size=1elifpromptisnotNoneandisinstance(prompt,list):batch_size=len(prompt)else:batch_size=prompt_embeds.shape[0]ifprompt_embedsisNone:# textual inversion: process multi-vector tokens if necessaryifisinstance(self,TextualInversionLoaderMixin):prompt=self.maybe_convert_prompt(prompt,self.tokenizer)text_inputs=self.tokenizer(prompt,padding="max_length",max_length=self.tokenizer.model_max_length,truncation=True,return_tensors="np",)text_input_ids=text_inputs.input_idsuntruncated_ids=self.tokenizer(prompt,padding="longest",return_tensors="np").input_idsifuntruncated_ids.shape[-1]>=text_input_ids.shape[-1]andnotnp.array_equal(text_input_ids,untruncated_ids):removed_text=self.tokenizer.batch_decode(untruncated_ids[:,self.tokenizer.model_max_length-1:-1])logger.warning("The following part of your input was truncated because CLIP can only handle sequences up to"f" {self.tokenizer.model_max_length} tokens: {removed_text}")ifhasattr(self.text_encoder.config,"use_attention_mask")andself.text_encoder.config.use_attention_mask:attention_mask=ms.Tensor(text_inputs.attention_mask)else:attention_mask=Noneifclip_skipisNone:prompt_embeds=self.text_encoder(ms.Tensor(text_input_ids),attention_mask=attention_mask)prompt_embeds=prompt_embeds[0]else:prompt_embeds=self.text_encoder(ms.Tensor(text_input_ids),attention_mask=attention_mask,output_hidden_states=True)# Access the `hidden_states` first, that contains a tuple of# all the hidden states from the encoder layers. Then index into# the tuple to access the hidden states from the desired layer.prompt_embeds=prompt_embeds[-1][-(clip_skip+1)]# We also need to apply the final LayerNorm here to not mess with the# representations. The `last_hidden_states` that we typically use for# obtaining the final prompt representations passes through the LayerNorm# layer.prompt_embeds=self.text_encoder.text_model.final_layer_norm(prompt_embeds)ifself.text_encoderisnotNone:prompt_embeds_dtype=self.text_encoder.dtypeelifself.unetisnotNone:prompt_embeds_dtype=self.unet.dtypeelse:prompt_embeds_dtype=prompt_embeds.dtypeprompt_embeds=prompt_embeds.to(dtype=prompt_embeds_dtype)bs_embed,seq_len,_=prompt_embeds.shape# duplicate text embeddings for each generation per prompt, using mps friendly methodprompt_embeds=prompt_embeds.tile((1,num_images_per_prompt,1))prompt_embeds=prompt_embeds.view(bs_embed*num_images_per_prompt,seq_len,-1)# get unconditional embeddings for classifier free guidanceifdo_classifier_free_guidanceandnegative_prompt_embedsisNone:uncond_tokens:List[str]ifnegative_promptisNone:uncond_tokens=[""]*batch_sizeelifpromptisnotNoneandtype(prompt)isnottype(negative_prompt):raiseTypeError(f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="f" {type(prompt)}.")elifisinstance(negative_prompt,str):uncond_tokens=[negative_prompt]elifbatch_size!=len(negative_prompt):raiseValueError(f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"" the batch size of `prompt`.")else:uncond_tokens=negative_prompt# textual inversion: process multi-vector tokens if necessaryifisinstance(self,TextualInversionLoaderMixin):uncond_tokens=self.maybe_convert_prompt(uncond_tokens,self.tokenizer)max_length=prompt_embeds.shape[1]uncond_input=self.tokenizer(uncond_tokens,padding="max_length",max_length=max_length,truncation=True,return_tensors="np",)ifhasattr(self.text_encoder.config,"use_attention_mask")andself.text_encoder.config.use_attention_mask:attention_mask=ms.Tensor(uncond_input.attention_mask)else:attention_mask=Nonenegative_prompt_embeds=self.text_encoder(ms.Tensor(uncond_input.input_ids),attention_mask=attention_mask,)negative_prompt_embeds=negative_prompt_embeds[0]ifdo_classifier_free_guidance:# duplicate unconditional embeddings for each generation per prompt, using mps friendly methodseq_len=negative_prompt_embeds.shape[1]negative_prompt_embeds=negative_prompt_embeds.to(dtype=prompt_embeds_dtype)negative_prompt_embeds=negative_prompt_embeds.tile((1,num_images_per_prompt,1))negative_prompt_embeds=negative_prompt_embeds.view(batch_size*num_images_per_prompt,seq_len,-1)ifself.text_encoderisnotNone:ifisinstance(self,LoraLoaderMixin):# Retrieve the original scale by scaling back the LoRA layersunscale_lora_layers(self.text_encoder,lora_scale)returnprompt_embeds,negative_prompt_embeds
defget_guidance_scale_embedding(self,w:ms.Tensor,embedding_dim:int=512,dtype:ms.Type=ms.float32)->ms.Tensor:""" See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 Args: w (`ms.Tensor`): Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. embedding_dim (`int`, *optional*, defaults to 512): Dimension of the embeddings to generate. dtype (`ms.dtype`, *optional*, defaults to `ms.float32`): Data type of the generated embeddings. Returns: `ms.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. """assertlen(w.shape)==1w=w*1000.0half_dim=embedding_dim//2emb=ops.log(ms.tensor(10000.0))/(half_dim-1)emb=ops.exp(ops.arange(half_dim,dtype=dtype)*-emb)emb=w.to(dtype)[:,None]*emb[None,:]emb=ops.cat([ops.sin(emb),ops.cos(emb)],axis=1)ifembedding_dim%2==1:# zero pademb=ops.pad(emb,(0,1))assertemb.shape==(w.shape[0],embedding_dim)returnemb
Source code in mindone/diffusers/pipelines/stable_diffusion/pipeline_output.py
10111213141516171819202122232425
@dataclassclassStableDiffusionPipelineOutput(BaseOutput):""" Output class for Stable Diffusion pipelines. Args: images (`List[PIL.Image.Image]` or `np.ndarray`) List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, num_channels)`. nsfw_content_detected (`List[bool]`) List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or `None` if safety checking could not be performed. """images:Union[List[PIL.Image.Image],np.ndarray]nsfw_content_detected:Optional[List[bool]]