Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Threed #7

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added .gitmodules
Empty file.
2 changes: 1 addition & 1 deletion .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
{"name":"Python Debugger: Current File","type":"debugpy","request":"launch","program":"${file}","console":"integratedTerminal"},
{
"name": "Python Debugger: Current File",
"type": "python",
"type": "debugpy",
"request": "launch",
"program": "${file}",
"python": "/Users/gleb/Library/Caches/pypoetry/virtualenvs/drawingwithgaussians-3hH70u0l-py3.11/bin/python",
Expand Down
4 changes: 2 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM ghcr.io/nvidia/jax:jax

RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y
RUN pip install numpy pillow matplotlib opencv-python einops optax hydra-core dm-pix flax diffusers transformers ftfy
RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 python3.10-venv -y
RUN apt remove cmake -y && pip install cmake --upgrade && pip install "pybind11[global]" poetry
8 changes: 5 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ This is not a "production-ready project" by any means but rather my attempts at
## Fit 2D gaussians to an image

```bash
python fit.py --config-name fit_to_image.yaml
python fit_2d.py --config-name fit_to_image.yaml
```

![An example of fitting an image](./static/eye_fitting.gif)
Expand All @@ -27,7 +27,7 @@ Here I initialize 50 gaussians and split them every epoch based on the gradient
## Fit 2D gaussians with a diffusion prior

```bash
python fit.py --config-name diffusion_guidance.yaml
python fit_2d.py --config-name diffusion_guidance.yaml
```

The prompt is `A man standing on the street`; it nicely deteriorates into an abstract image, probably due to the low number of gaussians (50->88->159->256->378). Annealing is on, img2img every 50 steps
Expand All @@ -52,8 +52,10 @@ A bit about the config:
- [ ] Ability to copy optimizer state from before the pruning (copy for the splitted gaussians)
- [ ] Test "deferred rendering" like in [SpacetimeGaussians](https://oppo-us-research.github.io/SpacetimeGaussians-website/)
- [x] Add SDS with SD
- [ ] Add basic 3D version
- [x] Add basic 3D version
- [ ] Add proper camera sampling (e.g. take from [vispy](https://vispy.org/api/vispy.scene.cameras.arcball.html))
- [ ] Add alternative alpha-compositing with occlusions (prune gaussians based on opacity; currently pruning is based on color norm, and I probably won't do this until I decide to move to 3D)

## References
Based on [3DGS](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/), [fmb-plus](https://leonidk.com/fmb-plus/), [GaussianImage](https://arxiv.org/abs/2403.08551), works ok on macbook m1 up to ~300 gaussians
Fast 3dgs borrowed from [jaxsplat](https://github.com/yklcs/jaxsplat)
1 change: 0 additions & 1 deletion drawingwithgaussians/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import jax
import jax.numpy as jnp
from flax import linen as nn
from jax import lax, random


def posenc(x, min_deg, max_deg, legacy_posenc_order=False):
Expand Down
2 changes: 2 additions & 0 deletions drawingwithgaussians/threed/camera.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
import jax.numpy as jnp
from jax.scipy.spatial.transform import Rotation as R
106 changes: 106 additions & 0 deletions drawingwithgaussians/threed/losses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from functools import partial

import jax
import jax.numpy as jnp
import jaxsplat

from drawingwithgaussians.twod.sds_pipeline import img2img


@jax.jit
def pixel_loss(
    target,
    means3d,
    scales,
    quats,
    colors,
    opacities,
    viewmat,
    background,
    focal,
    center,
    shape,
    glob_scale,
    clip_thresh,
    block_size,
):
    """Render the gaussians with ``jaxsplat.render`` and score them against ``target``.

    Args:
        target: Reference image the render is compared to; it is subtracted
            from the render, so it must broadcast against it.
        means3d: Gaussian centres, annotated as (N, 3) at the original call site.
        scales: Gaussian scales, annotated as (N, 3).
        quats: Gaussian rotations, annotated as normalized (N, 4) quaternions.
        colors: Gaussian colours, annotated as (N, 3).
        opacities: Gaussian opacities, annotated as (N, 1).
        viewmat: Camera view matrix, annotated as (4, 4).
        background: Background colour, annotated as (3,).
        focal: Forwarded as ``f``; annotated as ``(fx, fy)``.
        center: Forwarded as ``c``; annotated as ``(cx, cy)``.
        shape: Forwarded as ``img_shape``; annotated as ``(H, W)``.
        glob_scale: Forwarded unchanged to the renderer (annotated float).
        clip_thresh: Forwarded unchanged to the renderer (annotated float).
        block_size: Forwarded unchanged to the renderer (annotated int <= 16).

    Returns:
        A ``(loss, rendered)`` pair: the mean squared difference between the
        render and ``target``, and the rendered image itself.

    NOTE(review): the plain ``@jax.jit`` traces every argument, including
    ``shape``, ``focal``, ``center`` and ``block_size``. Confirm that
    ``jaxsplat.render`` accepts traced values for these, or that callers
    rely on this on purpose.
    """
    # Camera / rasterizer settings are keyword-only in the render call.
    render_kwargs = dict(
        viewmat=viewmat,
        background=background,
        img_shape=shape,
        f=focal,
        c=center,
        glob_scale=glob_scale,
        clip_thresh=clip_thresh,
        block_size=block_size,
    )
    rendered = jaxsplat.render(means3d, scales, quats, colors, opacities, **render_kwargs)

    residual = rendered - target
    return jnp.mean(jnp.square(residual)), rendered


def diffusion_guidance(
    gaussians_params,
    viewmat,
    background,
    focal,
    center,
    shape,
    glob_scale,
    clip_thresh,
    block_size,
    diffusion_shape,
    num_steps,
    strength,
    pipeline,
    params,
    dtype,
    cfg_scale,
    prompt,
    key,
    target_image=None,
):
    """Score a render of the gaussians against a diffusion-produced (or cached) image.

    The gaussians are rendered with ``jaxsplat.render``. When no
    ``target_image`` is supplied, the render (cast to ``dtype``, gradients
    stopped) is passed through ``img2img`` and the ``[0, 0]`` entry of the
    result is bilinearly resized to ``(shape[0], shape[1], 3)``. Otherwise a
    copy of ``target_image`` is used as the reference. The loss is the mean
    squared difference between the render and the gradient-stopped reference.

    Args:
        gaussians_params: 5-tuple ``(means3d, scales, quats, colors, opacities)``.
        viewmat, background, focal, center, shape, glob_scale, clip_thresh,
            block_size: Forwarded to ``jaxsplat.render`` as ``viewmat``,
            ``background``, ``f``, ``c``, ``img_shape``, ``glob_scale``,
            ``clip_thresh`` and ``block_size`` respectively.
        diffusion_shape: 3-element sequence; the first two entries are passed
            to ``img2img`` as height and width, the third is ignored.
        num_steps, strength, cfg_scale, pipeline, params, prompt, key:
            Forwarded unchanged to ``img2img``.
        dtype: dtype the render is cast to before being passed to ``img2img``.
        target_image: Optional reference image; when given, ``img2img`` is
            not called.

    Returns:
        ``(loss, (rendered, image))`` where ``image`` is the reference the
        render was compared against.
    """
    means3d, scales, quats, colors, opacities = gaussians_params

    rendered = jaxsplat.render(
        means3d,
        scales,
        quats,
        colors,
        opacities,
        viewmat=viewmat,
        background=background,
        img_shape=shape,
        f=focal,
        c=center,
        glob_scale=glob_scale,
        clip_thresh=clip_thresh,
        block_size=block_size,
    )

    # Unpacked unconditionally so a malformed diffusion_shape fails the same
    # way regardless of which branch is taken below.
    diff_height, diff_width, _ = diffusion_shape

    if target_image is not None:
        image = jnp.copy(target_image)
    else:
        # No gradient flows into or out of the diffusion pipeline.
        frozen_render = jax.lax.stop_gradient(rendered.astype(dtype))
        refined = img2img(
            frozen_render,
            prompt,
            key,
            diff_height,
            diff_width,
            num_steps,
            strength,
            cfg_scale,
            pipeline,
            params,
        )
        refined = jax.lax.stop_gradient(refined)
        out_shape = (shape[0], shape[1], 3)
        image = jax.image.resize(refined[0, 0], shape=out_shape, method="bilinear")

    reference = jax.lax.stop_gradient(image)
    loss = jnp.mean(jnp.square(rendered - reference))
    return loss, (rendered, image)
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from functools import partial

import jax
import jax.numpy as jnp
import optax
Expand All @@ -14,7 +12,6 @@ def init_gaussians(num_gaussians, target_image, key, optimize_background=True):
if not optimize_background:
background_color *= 0.0
height, weight, _ = target_image.shape
target_image = target_image
means = jax.random.uniform(key, (num_gaussians, 2), minval=0, maxval=height, dtype=jnp.float32)
sigmas = jax.random.uniform(key, (num_gaussians, 2), minval=1, maxval=height / 8, dtype=jnp.float32)
covariances = jnp.stack([jnp.diag(sigma**2) for sigma in sigmas])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@

import jax
import jax.numpy as jnp
import numpy as np
from dm_pix import ssim
from jax.experimental import io_callback

from .rendering2d import rasterize
from .rendering import rasterize
from .sds_pipeline import img2img


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import jax
import jax.numpy as jnp
import numpy as np
from diffusers.models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
from diffusers.pipelines.pipeline_flax_utils import FlaxDiffusionPipeline
from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionPipelineOutput
Expand All @@ -16,12 +15,11 @@
FlaxLMSDiscreteScheduler,
FlaxPNDMScheduler,
)
from diffusers.utils import PIL_INTERPOLATION, logging, replace_example_docstring
from diffusers.utils import logging
from einops import rearrange, repeat
from flax.core.frozen_dict import FrozenDict
from flax.jax_utils import replicate, unreplicate
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from jax import pmap
from PIL import Image
from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel

Expand Down Expand Up @@ -56,64 +54,6 @@ def img2img(image, prompt, key, height, width, num_steps, strength, cfg_scale, p
# Set to True to use python for loop instead of jax.fori_loop for easier debugging
DEBUG = False

EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import jax
>>> import numpy as np
>>> import jax.numpy as jnp
>>> from flax.jax_utils import replicate
>>> from flax.training.common_utils import shard
>>> import requests
>>> from io import BytesIO
>>> from PIL import Image
>>> from diffusers import FlaxStableDiffusionImg2ImgPipeline


>>> def create_key(seed=0):
... return jax.random.PRNGKey(seed)


>>> rng = create_key(0)

>>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
>>> response = requests.get(url)
>>> init_img = Image.open(BytesIO(response.content)).convert("RGB")
>>> init_img = init_img.resize((768, 512))

>>> prompts = "A fantasy landscape, trending on artstation"

>>> pipeline, params = FlaxStableDiffusionImg2ImgPipeline.from_pretrained(
... "CompVis/stable-diffusion-v1-4",
... revision="flax",
... dtype=jnp.bfloat16,
... )

>>> num_samples = jax.device_count()
>>> rng = jax.random.split(rng, jax.device_count())
>>> prompt_ids, processed_image = pipeline.prepare_inputs(
... prompt=[prompts] * num_samples, image=[init_img] * num_samples
... )
>>> p_params = replicate(params)
>>> prompt_ids = shard(prompt_ids)
>>> processed_image = shard(processed_image)

>>> output = pipeline(
... prompt_ids=prompt_ids,
... image=processed_image,
... params=p_params,
... prng_seed=rng,
... strength=0.75,
... num_inference_steps=50,
... jit=True,
... height=512,
... width=768,
... ).images

>>> output_images = pipeline.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
```
"""


class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
r"""
Expand Down Expand Up @@ -298,7 +238,6 @@ def loop_body(step, args):
image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
return image

@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt_ids: jnp.ndarray,
Expand Down
6 changes: 3 additions & 3 deletions fit.py → fit_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
from omegaconf import DictConfig, OmegaConf
from PIL import Image

from drawingwithgaussians.gaussian import init_gaussians, set_up_optimizers, split_n_prune, update
from drawingwithgaussians.losses import diffusion_guidance, pixel_loss
from drawingwithgaussians.sds_pipeline import FlaxStableDiffusionImg2ImgPipeline
from drawingwithgaussians.twod.gaussian import init_gaussians, set_up_optimizers, split_n_prune, update
from drawingwithgaussians.twod.losses import diffusion_guidance, pixel_loss
from drawingwithgaussians.twod.sds_pipeline import FlaxStableDiffusionImg2ImgPipeline


@hydra.main(version_base=None, config_path="./configs")
Expand Down
Loading