Stable Diffusion inference with ONNX models and ONNX Runtime

Using the diffusers pipeline directly:

# make sure you're logged in with `huggingface-cli login`
from diffusers import StableDiffusionOnnxPipeline

pipe = StableDiffusionOnnxPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="onnx",
    provider="CUDAExecutionProvider",
)

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
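
To keep the result, save the returned PIL image (a small follow-up to the snippet above; the file name is just an example):

image.save("astronaut_rides_horse.png")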

The following version, adapted from pipeline_onnx_stable_diffusion, calls the ONNX models directly and can serve as a reference for running the same steps on other inference engines:
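
The class below expects a local directory containing the exported ONNX files it loads (scheduler/scheduler_config.json, tokenizer/, text_encoder/model.onnx, unet/model.onnx, vae_decoder/model.onnx, and optionally vae_encoder/model.onnx). One way to obtain such a directory, assuming you use the "onnx" revision of runwayml/stable-diffusion-v1-5 and a recent huggingface_hub, is a snapshot download (a sketch, not the only option):

from huggingface_hub import snapshot_download

# download the ONNX weights of the "onnx" revision into a local folder
snapshot_download(
    "runwayml/stable-diffusion-v1-5",
    revision="onnx",
    local_dir="stable-diffusion-v1-5",
)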

pipe_onnx_sim.py

# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import inspect
from typing import Callable, List, Optional, Union

import numpy as np
import torch
from transformers import CLIPTokenizer
from diffusers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, DPMSolverMultistepScheduler

from onnx_utils import OnnxRuntimeModel
import logging as logger
from tqdm.auto import tqdm
from PIL import Image


class OnnxStableDiffusionPipeline():

    def __init__(self, model_dir):
        # scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],

        # stable-diffusion-v1-5 uses PNDMScheduler by default
        self.scheduler = PNDMScheduler.from_pretrained(os.path.join(model_dir, "scheduler/scheduler_config.json"))

        # stable-diffusion-2-1 uses DDIMScheduler by default
        # self.scheduler = DDIMScheduler.from_pretrained(os.path.join(model_dir, "scheduler/scheduler_config.json"))

        '''
        self.scheduler = DPMSolverMultistepScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
            trained_betas=None,
            predict_epsilon=True,
            thresholding=False,
            algorithm_type="dpmsolver++",
            solver_type="midpoint",
            lower_order_final=True,
        )
        '''

        self.tokenizer = CLIPTokenizer.from_pretrained(model_dir, subfolder="tokenizer")

        self.text_encoder = OnnxRuntimeModel(os.path.join(model_dir, "text_encoder/model.onnx"))

        # for text-to-image the vae_encoder is not needed; it is only used for image-to-image generation
        # self.vae_encoder = OnnxRuntimeModel(os.path.join(model_dir, "vae_encoder/model.onnx"))

        self.vae_decoder = OnnxRuntimeModel(os.path.join(model_dir, "vae_decoder/model.onnx"))
        self.unet = OnnxRuntimeModel(os.path.join(model_dir, "unet/model.onnx"))

        # check the scheduler config and warn about outdated parameters (the actual reset is left commented out)
        if hasattr(self.scheduler.config, "steps_offset") and self.scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {self.scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {self.scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            print("steps_offset!=1", "1.0.0", deprecation_message)
            # new_config = dict(scheduler.config)
            # new_config["steps_offset"] = 1
            # scheduler._internal_dict = FrozenDict(new_config)

        if hasattr(self.scheduler.config, "clip_sample") and self.scheduler.config.clip_sample is True:
            deprecation_message = (
                f"The configuration file of this scheduler: {self.scheduler} has not set the configuration `clip_sample`."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            print("clip_sample not set", "1.0.0", deprecation_message)
            # new_config = dict(scheduler.config)
            # new_config["clip_sample"] = False
            # scheduler._internal_dict = FrozenDict(new_config)

    def __call__(
        self,
        prompt: str,
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        num_inference_steps: Optional[int] = 25,
        guidance_scale: Optional[float] = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        latents: Optional[np.ndarray] = None,
        output_type: Optional[str] = "pil",
    ):
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        print("do_classifier_free_guidance:", do_classifier_free_guidance)

        prompt_embeds = self._encode_prompt(
            prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
        )
        prompt_embeds = prompt_embeds.astype("float32")
        print("prompt_embeds:", prompt_embeds, prompt_embeds.shape)

        # get the initial random noise unless the user supplied it
        latents_dtype = prompt_embeds.dtype
        latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)

        if latents is None:
            latents = np.random.randn(*latents_shape).astype(latents_dtype)
        elif latents.shape != latents_shape:
            raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")

        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)

        latents = latents * np.float32(self.scheduler.init_noise_sigma)

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())

        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # the exported unet used here takes int64 timesteps; other exports may expect a float dtype
        timestep_dtype = np.int64

        print("timesteps:", self.scheduler.timesteps)
        timesteps_tensor = self.scheduler.timesteps

        for i, t in enumerate(tqdm(timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
            latent_model_input = latent_model_input.cpu().numpy()
            latent_model_input = latent_model_input.astype("float32")

            # predict the noise residual
            timestep = np.array([t], dtype=timestep_dtype)

            # for a 512x512 image with classifier-free guidance the unet input is [2, 4, 64, 64]
            print("unet in shape:", latent_model_input.shape, timestep.shape, prompt_embeds.shape)

            noise_pred = self.unet(sample=latent_model_input, timestep=timestep,
                                   encoder_hidden_states=prompt_embeds)
            noise_pred = noise_pred[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            scheduler_output = self.scheduler.step(
                torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
            )
            latents = scheduler_output.prev_sample.numpy()

        # scale the latents back by the VAE scaling factor (0.18215 for SD v1.x)
        latents = 1 / 0.18215 * latents

        print("latents:", latents.shape)
        # for batch = 1
        image = self.vae_decoder(latent_sample=latents)[0]

        # the half-precision vae decoder seems to give odd results when batch size > 1; if so, decode one sample at a time:
        # image = np.concatenate(
        #     [self.vae_decoder(latent_sample=latents[i: i + 1])[0] for i in range(latents.shape[0])]
        # )

        image = np.clip(image / 2 + 0.5, 0, 1)
        image = image.transpose((0, 2, 3, 1))  # out shape: (1, 512, 512, 3)
        print("image:", image.shape)

        return image

    def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`):
                prompt to be encoded
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
        """
        batch_size = 1 if isinstance(prompt, str) else len(prompt)

        # get prompt text embeddings
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="np",
        )

        text_input_ids = text_inputs.input_ids

        # check inputs are truncated
        # untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids

        # if not np.array_equal(text_input_ids, untruncated_ids):
        #     removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1: -1])
        #     logger.warning(
        #         "The following part of your input was truncated because CLIP can only handle sequences up to"
        #         f" {self.tokenizer.model_max_length} tokens: {removed_text}"
        #     )

        prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
        # gen multi images per prompt
        # prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            else:
                uncond_tokens = negative_prompt

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="np",
            )
            negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
            negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])

        return prompt_embeds

    @staticmethod
    def numpy_to_pil(images):
        """
        Convert a numpy image or a batch of images to a PIL image.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # special case for grayscale (single channel) images
            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
        else:
            pil_images = [Image.fromarray(image) for image in images]

        return pil_images


model_dir = "stable-diffusion-v1-5"
onnx_pipe = OnnxStableDiffusionPipeline(model_dir)
prompt = "beautiful victorian raven digital painting, art by artgerm and greg rutkowski, alphonse mucha, cgsociety"
image = onnx_pipe(prompt)

images = onnx_pipe.numpy_to_pil(image)
for i, image in enumerate(images):
    image.save(f"generated_image_{i}.png")

onnx_utils.py

import logging as logger
import numpy as np
import os
import onnxruntime as ort


ORT_TO_NP_TYPE = {
    "tensor(bool)": np.bool_,
    "tensor(int8)": np.int8,
    "tensor(uint8)": np.uint8,
    "tensor(int16)": np.int16,
    "tensor(uint16)": np.uint16,
    "tensor(int32)": np.int32,
    "tensor(uint32)": np.uint32,
    "tensor(int64)": np.int64,
    "tensor(uint64)": np.uint64,
    "tensor(float16)": np.float16,
    "tensor(float)": np.float32,
    "tensor(double)": np.float64,
}
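
# ORT_TO_NP_TYPE maps the type strings that onnxruntime reports for graph inputs/outputs
# (e.g. sess.get_inputs()[0].type == "tensor(float)") to numpy dtypes, so host data can be
# cast to whatever dtype the exported graph expects. A usage sketch (the input name
# "timestep" and the variable `sess` are illustrative):
#   timestep_type = next(i.type for i in sess.get_inputs() if i.name == "timestep")
#   timestep = np.array([t], dtype=ORT_TO_NP_TYPE[timestep_type])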


class OnnxRuntimeModel:
    def __init__(self, model_path, provider=None, sess_options=None):
        self.model = None
        if model_path:
            self.load_model(model_path, provider=provider, sess_options=sess_options)

    def __call__(self, **kwargs):
        inputs = {k: np.array(v) for k, v in kwargs.items()}
        return self.model.run(None, inputs)

    def load_model(self, path: str, provider=None, sess_options=None):
        """
        Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider`

        Arguments:
            path (`str` or `Path`):
                Directory from which to load
            provider(`str`, *optional*):
                Onnxruntime execution provider to use for loading the model, defaults to `CPUExecutionProvider`
        """
        if provider is None:
            logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
            provider = "CPUExecutionProvider"

        self.model = ort.InferenceSession(path, providers=[provider], sess_options=sess_options)
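
By default the session is created with CPUExecutionProvider. To run a model on GPU, pass a different execution provider when constructing the wrapper (a minimal sketch, assuming onnxruntime-gpu with the CUDA provider is installed):

unet_gpu = OnnxRuntimeModel(
    "stable-diffusion-v1-5/unet/model.onnx",
    provider="CUDAExecutionProvider",
)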
