Traceback (most recent call last):
  File "/content/train_dreambooth.py", line 21, in <module>
    from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
  File "/usr/local/lib/python3.10/dist-packages/diffusers/__init__.py", line 36, in <module>
    from .models import (
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/__init__.py", line 33, in <module>
    from .controlnet_flax import FlaxControlNetModel
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/controlnet_flax.py", line 25, in <module>
    from .modeling_flax_utils import FlaxModelMixin
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/modeling_flax_utils.py", line 45, in <module>
    class FlaxModelMixin:
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/modeling_flax_utils.py", line 192, in FlaxModelMixin
    def init_weights(self, rng: jax.random.KeyArray) -> Dict:
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/deprecations.py", line 54, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.random' has no attribute 'KeyArray'