
[Arxiv 2024] From Parts to Whole: A Unified Reference Framework for Controllable Human Image Generation


huanngzh/Parts2Whole


Parts2Whole


  • Inference code and pretrained models.
  • Evaluation code.
  • Training code.
  • Training data.
  • New model based on Stable Diffusion 2-1.

🔥 Updates

[2024-06-21] Released our training code. See Training.
[2024-05-26] Released the dataset. See Dataset.
[2024-05-06] 🔥🔥🔥 Code is released. Enjoy the human parts composition!

[Figure: teaser]

Abstract: We propose Parts2Whole, a novel framework designed for generating customized portraits from multiple reference images, including pose images and various aspects of human appearance. First, we develop a semantic-aware appearance encoder that retains the details of different human parts: guided by its textual label, it encodes each image into a series of multi-scale feature maps rather than a single image token, preserving the image dimension. Second, our framework supports multi-image conditioned generation through a shared self-attention mechanism that operates across reference and target features during the diffusion process. We enhance the vanilla attention mechanism by incorporating mask information from the reference human images, allowing precise selection of any part.
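
To make the masked shared self-attention idea concrete, here is a minimal sketch in plain Python. It is not the repository's implementation: it assumes single-head attention with identity Q/K/V projections, and a hypothetical per-token binary mask `ref_mask` over the reference tokens (1 = the token lies inside the selected part). Target tokens attend over the concatenation of target and reference tokens, and masked-out reference tokens receive a score of -inf so they get zero weight.

```python
import math
from typing import List, Sequence

Vector = List[float]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def softmax(scores: Sequence[float]) -> List[float]:
    # Subtract the max for numerical stability; -inf scores get weight 0.
    m = max(scores)
    exps = [math.exp(s - m) if s != float("-inf") else 0.0 for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def masked_shared_attention(
    target: List[Vector],
    references: List[Vector],
    ref_mask: List[int],
) -> List[Vector]:
    """Each target token attends over [target tokens + reference tokens].

    Reference tokens whose mask bit is 0 (outside the selected part) are
    excluded by giving them a score of -inf before the softmax.
    Q/K/V projections are omitted (identity) to keep the sketch short.
    """
    dim = len(target[0])
    scale = 1.0 / math.sqrt(dim)
    keys = target + references
    # Target tokens are always visible; reference tokens only if masked in.
    visible = [1] * len(target) + list(ref_mask)
    out = []
    for q in target:
        scores = [
            dot(q, k) * scale if v else float("-inf")
            for k, v in zip(keys, visible)
        ]
        weights = softmax(scores)
        # Weighted sum of the (identity-projected) values.
        out.append([sum(w * k[d] for w, k in zip(weights, keys)) for d in range(dim)])
    return out
```

Because masked-out reference tokens receive zero weight, changing their features leaves the output unchanged; in the actual model the reference features would come from the appearance encoder's multi-scale feature maps rather than toy vectors.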

🔨 Method Overview

[Figure: pipeline overview]

⚒️ Installation

Clone our repo and install the packages in requirements.txt. We tested our model on an 80GB A800 GPU with CUDA 11.8 and PyTorch 2.0.1, but inference on smaller GPUs is possible.

conda create -n parts2whole
conda activate parts2whole
pip install -r requirements.txt

Download the checkpoints into the pretrained_weights/parts2whole directory. We also provide a simple download script:

python download_weights.py

🎨 Inference

Check inference.py, modify the checkpoint path and inputs as needed, and run:

python inference.py

You may need to modify the following code in the inference.py script:

import torch  # needed for torch.float16 below

### Define configurations ###
device = "cuda"
torch_dtype = torch.float16
seed = 42
model_dir = "pretrained_weights/parts2whole"  # checkpoint path on your local machine
use_decoupled_cross_attn = True
decoupled_cross_attn_path = "pretrained_weights/parts2whole/decoupled_attn.pth"  # included in model_dir
### Define input data ###
height, width = 768, 512
prompt = "This person is wearing a short-sleeve shirt." # input prompt
input_dict = {
    "appearance": {
        "face": "testset/face_man1.jpg",
        "whole body clothes": "testset/clothes_man1.jpg",
    },
    "mask": {
        "face": "testset/face_man1_mask.jpg",
        "whole body clothes": "testset/clothes_man1_mask.jpg",
    },
    "structure": {"densepose": "testset/densepose_man1.jpg"},
}

⭐️⭐️⭐️ Notably, input_dict should contain the keys appearance, mask, and structure. The first two specify the appearance of parts taken from multiple reference images (with one mask per part), and structure specifies the posture, such as a DensePose map.

⭐️⭐️⭐️ The keys within these three entries are also constrained. Keys in appearance and mask must be identical, chosen from "upper body clothes", "lower body clothes", "whole body clothes", "hair or headwear", "face", and "shoes". The key of structure should be "densepose". (The OpenPose model has not been released yet.)
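
The key rules above can be checked before running inference. The following is a hypothetical helper (`validate_input_dict` is not part of the repository), a minimal sketch assuming input_dict has the nested shape shown in the inference example; it checks only key names, not whether the image files exist.

```python
from typing import Dict

# Part labels listed in this README for the "appearance" and "mask" entries.
ALLOWED_PARTS = {
    "upper body clothes",
    "lower body clothes",
    "whole body clothes",
    "hair or headwear",
    "face",
    "shoes",
}
# Only DensePose is supported as a structure signal for now.
ALLOWED_STRUCTURES = {"densepose"}


def validate_input_dict(input_dict: Dict[str, Dict[str, str]]) -> None:
    """Raise ValueError if input_dict breaks the key rules (hypothetical helper)."""
    required = {"appearance", "mask", "structure"}
    missing = required - set(input_dict)
    if missing:
        raise ValueError(f"missing top-level keys: {sorted(missing)}")

    # Every appearance part needs a mask with exactly the same key, and vice versa.
    appearance_keys = set(input_dict["appearance"])
    mask_keys = set(input_dict["mask"])
    if appearance_keys != mask_keys:
        raise ValueError(
            f"appearance and mask keys differ: {sorted(appearance_keys ^ mask_keys)}"
        )
    unknown = appearance_keys - ALLOWED_PARTS
    if unknown:
        raise ValueError(f"unsupported part labels: {sorted(unknown)}")

    structure_keys = set(input_dict["structure"])
    if not structure_keys or not structure_keys <= ALLOWED_STRUCTURES:
        raise ValueError(f"structure keys must be {sorted(ALLOWED_STRUCTURES)}")
```

Calling it on the example input_dict from inference.py returns silently; a typo such as "hat" instead of "hair or headwear", or an "openpose" structure key, raises a ValueError with the offending keys.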

🔨🔨🔨 To conveniently obtain the mask of each reference image, we also provide tools and explain how to use them in Tools. First, you can use Real-ESRGAN to increase the resolution of the reference image, and then use SegFormer to obtain masks for the various parts of the human body.

📊 Dataset

[Figure: dataset sample]

Our dataset has been released. We provide a download-and-unzip script, download_dataset.py; run it with the following command:

python download_dataset.py

It prepares the dataset in the folder data/DeepFashion-MultiModal-Parts2Whole, so that you can train the model with our config, or run our dataset file parts2whole/data/ref_trg.py to inspect the dataset.

Make sure that the dataset is organized as follows:

DeepFashion-MultiModal-Parts2Whole
# Structure signals
|-- densepose
|-- openpose
# Appearance conditions
|-- face
|-- hair_headwear
|-- lower_body_clothes
|-- upper_body_clothes
|-- whole_body_clothes
|-- shoes
# Target images
|-- images
# Caption file
|-- train.jsonl
`-- test.jsonl

This human image dataset comprises about 41,500 reference-target pairs. Each pair includes multiple reference images, namely pose maps and various aspects of human appearance (e.g., hair, face, clothes, shoes), a target image featuring the same individual (ID), and textual captions. For details about the dataset, refer to our dataset repo.

Our dataset is post-processed from the DeepFashion-MultiModal dataset.
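
Before training, it can help to confirm that the download produced the expected layout. This is a minimal sketch, not a script shipped with the repository: `missing_entries` and `count_records` are hypothetical helpers that only check that the listed folders and caption files exist and count the non-empty lines in a JSON Lines file. They make no assumptions about the fields inside each record.

```python
from pathlib import Path
from typing import List

# Sub-directories and files listed in the expected dataset layout.
EXPECTED_DIRS = [
    "densepose", "openpose",
    "face", "hair_headwear", "lower_body_clothes",
    "upper_body_clothes", "whole_body_clothes", "shoes",
    "images",
]
EXPECTED_FILES = ["train.jsonl", "test.jsonl"]


def missing_entries(root: str) -> List[str]:
    """Return the expected entries that are absent under root (empty list = OK)."""
    base = Path(root)
    missing = [d for d in EXPECTED_DIRS if not (base / d).is_dir()]
    missing += [f for f in EXPECTED_FILES if not (base / f).is_file()]
    return missing


def count_records(jsonl_path: str) -> int:
    """Count non-empty lines in a JSON Lines file (one record per line)."""
    with open(jsonl_path, encoding="utf-8") as fh:
        return sum(1 for line in fh if line.strip())


if __name__ == "__main__":
    root = "data/DeepFashion-MultiModal-Parts2Whole"
    problems = missing_entries(root)
    print("layout OK" if not problems else f"missing: {problems}")
```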

🏋️ Training

To train Parts2Whole on a single device, use the following command:

python train.py --config configs/train-sd15.yaml

To train in a DDP environment (assuming 8 devices per machine here), run:

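# 8 is the number of devices per machine; WORLD_SIZE is the number of machines.
# (A comment after a trailing backslash would break the line continuation.)
accelerate launch \
--mixed_precision=fp16 \
--num_processes=$((8*$WORLD_SIZE)) \
--num_machines=$WORLD_SIZE \
--multi_gpu \
--machine_rank=$RANK \
train.py --config configs/train-sd15.yaml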

In our config file, the per-device batch size is set to 8, which is recommended for a device with 80GB of memory. If you train on a device with less memory, reduce it accordingly.
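
The arithmetic behind these settings: with 8 samples per device and 8 devices on one machine, each optimizer step sees 8 × 8 × 1 = 64 samples. If you lower the per-device batch size, the global batch shrinks unless you compensate with gradient accumulation. The README does not say whether the training config exposes gradient accumulation, so treat the helpers below (`global_batch_size`, `accum_steps_for`, both hypothetical) as a sketch for planning only.

```python
import math


def global_batch_size(per_device: int, devices_per_machine: int,
                      num_machines: int, grad_accum_steps: int = 1) -> int:
    """Samples contributing to one optimizer step across all processes."""
    return per_device * devices_per_machine * num_machines * grad_accum_steps


def accum_steps_for(target_global: int, per_device: int,
                    devices_per_machine: int, num_machines: int) -> int:
    """Gradient-accumulation steps needed to reach at least target_global."""
    per_step = per_device * devices_per_machine * num_machines
    return math.ceil(target_global / per_step)


# Default setup: 8 per device, 8 GPUs, 1 machine.
print(global_batch_size(8, 8, 1))
# Smaller GPUs with 2 per device: steps needed to keep the same global batch.
print(accum_steps_for(global_batch_size(8, 8, 1), 2, 8, 1))
```

Keeping the global batch fixed this way avoids having to retune the learning rate, at the cost of more forward/backward passes per optimizer step.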

😊 Evaluation

For evaluation, first install the following additional packages:

pip install git+https://github.com/openai/CLIP.git # for clip
pip install dreamsim # for dreamsim
pip install lpips # for lpips

We provide easy-to-use evaluation scripts in the scripts/evals folder. The scripts take data in a unified format: two lists of images as input. Modify the image-loading code as needed, and check our scripts for more details.
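
The "two lists of images" convention means the i-th generated image is compared with the i-th reference image, and the per-pair scores are averaged. Below is a minimal sketch of that pattern, not the code in scripts/evals: `evaluate_pairs` is a hypothetical helper, and `mean_abs_error` on flattened pixel lists is only a toy stand-in for the CLIP, DreamSim, or LPIPS models the real scripts use.

```python
from statistics import mean
from typing import Callable, List, Sequence, TypeVar

Image = TypeVar("Image")


def evaluate_pairs(
    generated: Sequence[Image],
    references: Sequence[Image],
    metric: Callable[[Image, Image], float],
) -> float:
    """Average a pairwise metric over two aligned lists of images."""
    # The lists must be aligned: generated[i] is scored against references[i].
    if len(generated) != len(references):
        raise ValueError(f"list lengths differ: {len(generated)} vs {len(references)}")
    if not generated:
        raise ValueError("no images to evaluate")
    return mean(metric(g, r) for g, r in zip(generated, references))


def mean_abs_error(a: List[float], b: List[float]) -> float:
    """Toy stand-in metric on flattened pixel lists (not CLIP/DreamSim/LPIPS)."""
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)
```

Swapping in a real metric only requires a callable that takes two loaded images and returns a float; the length check catches the most common mistake, a mismatched or reordered image list.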

🔨 Tools

Real-ESRGAN

To restore images with Real-ESRGAN, first download RealESRGAN_x4plus.pth into ./pretrained_weights/Real-ESRGAN, then run:

python -m scripts.real_esrgan -n RealESRGAN_x4plus -i /path/to/dir -o /path/to/dir --face_enhance

SegFormer

To segment human images with SegFormer and obtain hat, hair, face, and clothes parts, run:

python scripts/segformer_b2_clothes.py --image-path /path/to/image --output-dir /path/to/dir

Labels: 0: "Background", 1: "Hat", 2: "Hair", 3: "Sunglasses", 4: "Upper-clothes", 5: "Skirt", 6: "Pants", 7: "Dress", 8: "Belt", 9: "Left-shoe", 10: "Right-shoe", 11: "Face", 12: "Left-leg", 13: "Right-leg", 14: "Left-arm", 15: "Right-arm", 16: "Bag", 17: "Scarf"
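
To turn a SegFormer label map into the per-part masks that input_dict expects, the class ids above must be grouped into Parts2Whole part names. The grouping below is an assumption made for illustration (for example, treating "whole body clothes" as upper clothes, skirt, pants, and dress); it is not taken from the repository. The sketch also assumes you already have the per-pixel label map as a 2D grid of the class ids listed above.

```python
from typing import Dict, List, Set

LabelMap = List[List[int]]  # H x W grid of SegFormer class ids
Mask = List[List[int]]      # H x W grid of 0/1

# ASSUMPTION: this grouping of SegFormer ids into Parts2Whole part names is
# illustrative only; it is not taken from the repository.
PART_TO_LABELS: Dict[str, Set[int]] = {
    "hair or headwear": {1, 2},          # Hat, Hair
    "face": {11},                        # Face
    "upper body clothes": {4},           # Upper-clothes
    "lower body clothes": {5, 6},        # Skirt, Pants
    "whole body clothes": {4, 5, 6, 7},  # Upper-clothes, Skirt, Pants, Dress
    "shoes": {9, 10},                    # Left-shoe, Right-shoe
}


def part_mask(label_map: LabelMap, part: str) -> Mask:
    """Binary mask that is 1 wherever the pixel's class belongs to `part`."""
    ids = PART_TO_LABELS[part]
    return [[1 if px in ids else 0 for px in row] for row in label_map]
```

The resulting 0/1 grid can then be scaled to 0/255 and saved as an image (e.g., with Pillow) to use as a mask path in input_dict.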

😭 Limitations

At present, the training data has limited diversity and contains relatively more women than men, so the model's generalization (e.g., to stylized inputs) still needs improvement. We are working to improve the robustness and capabilities of the model, and we welcome contributions and pull requests from the community.

🤝 Acknowledgement

We thank the authors of the following open-source projects:

diffusers, magic-animate, Moore-AnimateAnyone, DeepFashion-MultiModal, Real-ESRGAN

📎 Citation

If you find this repository useful, please consider citing:

@article{huang2024parts2whole,
  title={From Parts to Whole: A Unified Reference Framework for Controllable Human Image Generation},
  author={Huang, Zehuan and Fan, Hongxing and Wang, Lipeng and Sheng, Lu},
  journal={arXiv preprint arXiv:2404.15267},
  year={2024}
}
