diff --git a/.DS_Store b/.DS_Store
index e955512..4e39217 100644
Binary files a/.DS_Store and b/.DS_Store differ
diff --git a/README.md b/README.md
index 37c4684..031b496 100644
--- a/README.md
+++ b/README.md
@@ -17,7 +17,178 @@ Official implementation of DCT-Net for Full-body Portrait Stylization.
## Demo
-![demo_vid](assets/demo.gif)
+![demo](assets/demo.gif)
+
+
+## News
+
+(2023-03-14) The training guidance has been released; train DCT-Net with your own style data.
+
+(2023-02-20) Two new style pre-trained models (design, illustration), trained by combining DCT-Net and Stable-Diffusion, are provided. The training guidance will be released soon.
+
+(2022-10-09) The multi-style pre-trained models (3d, handdrawn, sketch, artstyle) and their usage are available now.
+
+(2022-08-08) The pretrained model and inference code for the 'anime' style are available now. More styles coming soon.
+
+(2022-08-08) The cartoon function can be called directly from the Python SDK.
+
+(2022-07-07) The paper is now available on arXiv (https://arxiv.org/abs/2207.02426).
+
+
+## Web Demo
+- Integrated into [Colab notebook](https://colab.research.google.com/github/menyifang/DCT-Net/blob/main/notebooks/inference.ipynb). Try out the Colab demo.
+
+- Integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). Try out the Web Demo [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/SIGGRAPH2022/DCT-Net)
+
+- [Chinese version] Integrated into [ModelScope](https://modelscope.cn/#/models). Try out the Web Demo [![ModelScope Spaces](https://img.shields.io/badge/ModelScope-Spaces-blue)](https://modelscope.cn/#/models/damo/cv_unet_person-image-cartoon_compound-models/summary)
+
+## Requirements
+* python 3
+* tensorflow (>=1.14; training only supports TF 1.x)
+* easydict
+* numpy
+* both CPU and GPU are supported
+
+
+## Quick Start
+
+
+
+```bash
+git clone https://github.com/menyifang/DCT-Net.git
+cd DCT-Net
+
+```
+
+### Installation
+```bash
+conda create -n dctnet python=3.7
+conda activate dctnet
+pip install --upgrade tensorflow-gpu==1.15  # GPU support; use the plain `tensorflow` package for CPU-only
+pip install "modelscope[cv]==1.3.2" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
+pip install "modelscope[multi-modal]==1.3.2" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
+```
+
+### Downloads
+
+| [anime](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon_compound-models/summary) | [3d](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon-3d_compound-models/summary) | [handdrawn](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon-handdrawn_compound-models/summary) | [sketch](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon-sketch_compound-models/summary) | [artstyle](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon-artstyle_compound-models/summary) |
+|:--:|:--:|:--:|:--:|:--:|
+
+| [design](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon-sd-design_compound-models/summary) | [illustration](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon-sd-illustration_compound-models/summary) |
+|:--:|:--:|
+
+Pre-trained models in different styles can be downloaded with:
+```bash
+python download.py
+```
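+
+If you only need a single style, the same assets can also be fetched through the ModelScope hub API. A minimal sketch, assuming the standard snapshot-download call (the cache directory is a placeholder):
+
+```python
+from modelscope.hub.snapshot_download import snapshot_download
+
+# download the 'anime' compound model into the current directory
+model_dir = snapshot_download('damo/cv_unet_person-image-cartoon_compound-models', cache_dir='.')
+print('model assets saved to %s' % model_dir)
+```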
+
+### Inference
+
+- From the Python SDK (see the sketch after these commands):
+```bash
+python run_sdk.py
+```
+
+- From the source code:
+```bash
+python run.py
+```
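+
+For reference, a minimal sketch of the Python-SDK route, assuming the ModelScope portrait-stylization pipeline (`input.png` is a placeholder path):
+
+```python
+import cv2
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+# build the anime-style cartoonization pipeline and run it on a single image
+img_cartoon = pipeline(Tasks.image_portrait_stylization,
+                       model='damo/cv_unet_person-image-cartoon_compound-models')
+result = img_cartoon('input.png')
+cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG])
+```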
+
+### Video cartoonization
+
+![demo_vid](assets/video.gif)
+
+Videos are processed directly as image sequences. Style options: anime, 3d, handdrawn, sketch, artstyle, sd-design, sd-illustration.
+
+```bash
+python run_vid.py --style anime
+```
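+
+A rough sketch of the "image sequence" idea, assuming the portrait-stylization pipeline accepts numpy frames (paths and fps handling are placeholders):
+
+```python
+import cv2
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+img_cartoon = pipeline(Tasks.image_portrait_stylization,
+                       model='damo/cv_unet_person-image-cartoon_compound-models')
+
+cap = cv2.VideoCapture('input.mp4')
+fps = cap.get(cv2.CAP_PROP_FPS) or 30
+writer = None
+while True:
+    ok, frame = cap.read()
+    if not ok:
+        break
+    # stylize one frame; assumes ndarray input is accepted by the pipeline
+    res = img_cartoon(frame)[OutputKeys.OUTPUT_IMG]
+    if writer is None:
+        h, w = res.shape[:2]
+        writer = cv2.VideoWriter('res.mp4', cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
+    writer.write(res.astype('uint8'))
+cap.release()
+if writer is not None:
+    writer.release()
+```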
+
+
+## Training
+
+### Data preparation
+```
+face_photo: face dataset such as [FFHQ](https://github.com/NVlabs/ffhq-dataset) or other collected real faces.
+face_cartoon: 100-300 cartoon face images in a specific style, which can be self-collected or synthesized with generative models.
+```
+Due to copyright issues, we cannot provide the collected cartoon exemplars for training. You can produce cartoon exemplars with the style-finetuned Stable-Diffusion (SD) models, which can be downloaded from the ModelScope or Hugging Face hubs.
+
+The effects of some style-finetuned SD models are shown below:
+
+| [design](https://modelscope.cn/models/damo/cv_cartoon_stable_diffusion_design/summary) | [watercolor](https://modelscope.cn/models/damo/cv_cartoon_stable_diffusion_watercolor/summary) | [illustration](https://modelscope.cn/models/damo/cv_cartoon_stable_diffusion_illustration/summary) | [clipart](https://modelscope.cn/models/damo/cv_cartoon_stable_diffusion_clipart/summary) | [flat](https://modelscope.cn/models/damo/cv_cartoon_stable_diffusion_flat/summary) |
+|:--:|:--:|:--:|:--:|:--:|
+
+- Generate stylized data. Style options: clipart, design, illustration, watercolor, flat.
+```bash
+python generate_data.py --style clipart
+```
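+
+The style-finetuned SD models can also be called directly through the ModelScope text-to-image pipeline. A minimal single-image sketch, mirroring the fastTrain notebook (clipart style; the prompt and output path are placeholders):
+
+```python
+import cv2
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from diffusers.schedulers import EulerAncestralDiscreteScheduler
+
+pipe = pipeline(Tasks.text_to_image_synthesis,
+                model='damo/cv_cartoon_stable_diffusion_clipart',
+                model_revision='v1.0.0')
+# swap in Euler-Ancestral sampling, as done in the notebook
+pipe.pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.pipeline.scheduler.config)
+
+output = pipe({'text': 'archer style, a portrait painting of Johnny Depp'})
+cv2.imwrite('result.png', output['output_imgs'][0])
+```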
+
+- Preprocess
+
+Extract aligned faces from the raw style images:
+```bash
+python extract_align_faces.py --src_dir 'data/raw_style_data'
+```
+
+- Train the content calibration network
+
+Install the environment required by [stylegan2-pytorch](https://github.com/rosinality/stylegan2-pytorch):
+```bash
+cd source/stylegan2
+python prepare_data.py '../../data/face_cartoon' --size 256 --out '../../data/stylegan2/traindata'
+CUDA_VISIBLE_DEVICES=0 python train_condition.py --name 'ffhq_style_s256' --path '../../data/stylegan2/traindata' --config config/conf_server_train_condition_shell.json
+```
+
+After training, generate content-calibrated samples via:
+```bash
+python style_blend.py --name 'ffhq_style_s256'
+python generate_blendmodel.py --name 'ffhq_style_s256' --save_dir '../../data/face_cartoon/syn_style_faces'
+```
+
+- Geometry calibration
+
+Run geometry calibration on both the photo and cartoon data:
+```bash
+cd source
+python image_flip_agument_parallel.py --data_dir '../data/face_cartoon'
+python image_scale_agument_parallel_flat.py --data_dir '../data/face_cartoon'
+python image_rotation_agument_parallel_flat.py --data_dir '../data/face_cartoon'
+```
+
+- Train the texture translator
+
+The recommended dataset structure is:
+```
++—data
+| +—face_photo
+| +—face_cartoon
+```
+
+Resume training from a pretrained model in a similar style; the style can be chosen from 'anime, 3d, handdrawn, sketch, artstyle, sd-design, sd-illustration'.
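+
+The entry point is the `train_localtoon.py` script (the data and work-dir paths below are placeholders):
+
+```bash
+python train_localtoon.py --data_dir PATH_TO_YOUR_DATA --work_dir PATH_SAVE --style anime
+```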
## News
@@ -44,7 +215,6 @@ https://img.shields.io/badge/ModelScope-Spaces-blue)](https://modelscope.cn/#/mo
* python 3
* tensorflow (>=1.14)
* easydict
-* imageio[ffmpeg]
* numpy
* both CPU/GPU are supported
@@ -63,9 +233,11 @@ cd DCT-Net
```bash
conda create -n dctnet python=3.7
conda activate dctnet
+pip install numpy==1.18.5
+pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
pip install --upgrade tensorflow-gpu==1.15 # GPU support, use tensorflow for CPU only
-pip install "modelscope[cv]==1.3.0" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
-pip install "modelscope[multi-modal]==1.3.0" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
+pip install "modelscope[cv]==1.3.2" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
+pip install "modelscope[multi-modal]==1.3.2" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
```
### Downloads
@@ -102,11 +274,12 @@ python run.py
video can be directly processed as image sequences, style choice [option: anime, 3d, handdrawn, sketch, artstyle, sd-design, sd-illustration]
```bash
-python run_vid.py --video_path input.mp4 --save_path res.mp4 --style anime
+python run_vid.py --style anime
```
## Training
+
### Data preparation
```
@@ -126,16 +299,54 @@ The effects of some style-finetune SD models are as follows:
python generate_data.py --style clipart
```
+- Preprocess
-### Train content calibration network
-To-be-added
+Extract aligned faces from the raw style images:
+```bash
+python extract_align_faces.py --src_dir 'data/raw_style_data'
+```
-### Geometry calibration
-To-be-added
+- Train the content calibration network
-### Train texture translator
-To-be-added
+Install the environment required by [stylegan2-pytorch](https://github.com/rosinality/stylegan2-pytorch):
+```bash
+cd source/stylegan2
+python prepare_data.py '../../data/face_cartoon' --size 256 --out '../../data/stylegan2/traindata'
+CUDA_VISIBLE_DEVICES=0 python train_condition.py --name 'ffhq_style_s256' --path '../../data/stylegan2/traindata' --config config/conf_server_train_condition_shell.json
+```
+
+After training, generate content-calibrated samples via:
+```bash
+python style_blend.py --name 'ffhq_style_s256'
+python generate_blendmodel.py --name 'ffhq_style_s256' --save_dir '../../data/face_cartoon/syn_style_faces'
+```
+
+- Geometry calibration
+
+Run geometry calibration on both the photo and cartoon data:
+```bash
+cd source
+python image_flip_agument_parallel.py --data_dir '../data/face_cartoon'
+python image_scale_agument_parallel_flat.py --data_dir '../data/face_cartoon'
+python image_rotation_agument_parallel_flat.py --data_dir '../data/face_cartoon'
+```
+
+- Train the texture translator
+
+The recommended dataset structure is:
+```
++—data
+| +—face_photo
+| +—face_cartoon
+```
+
+Resume training from a pretrained model in a similar style; the style can be chosen from 'anime, 3d, handdrawn, sketch, artstyle, sd-design, sd-illustration'.
+
+```bash
+python train_localtoon.py --data_dir PATH_TO_YOUR_DATA --work_dir PATH_SAVE --style anime
+```
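+
+Alternatively, training can be driven through the ModelScope trainer API. A condensed sketch based on the fastTrain notebook in this repo (the data directory is a placeholder; roughly 30000 steps are recommended for real training):
+
+```python
+import os
+from modelscope.trainers.cv import CartoonTranslationTrainer
+
+model_id = 'damo/cv_unet_person-image-cartoon_compound-models'
+data_dir = 'data'  # placeholder: folder containing face_photo/ and face_cartoon/
+
+trainer = CartoonTranslationTrainer(
+    model=model_id,
+    work_dir='exp_localtoon',
+    photo=os.path.join(data_dir, 'face_photo'),
+    cartoon=os.path.join(data_dir, 'face_cartoon'),
+    max_steps=30000)
+trainer.train()
+```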
@@ -166,7 +377,3 @@ If you find this code useful for your research, please use the following BibTeX
-
-
-
-
diff --git a/extract_align_faces.py b/extract_align_faces.py
new file mode 100644
index 0000000..1620f14
--- /dev/null
+++ b/extract_align_faces.py
@@ -0,0 +1,162 @@
+import cv2
+import os
+import numpy as np
+import argparse
+from source.facelib.facer import FaceAna
+import source.utils as utils
+from source.mtcnn_pytorch.src.align_trans import warp_and_crop_face, get_reference_facial_points
+from modelscope.hub.snapshot_download import snapshot_download
+
+class FaceProcesser:
+ def __init__(self, dataroot, crop_size = 256, max_face = 1):
+ self.max_face = max_face
+ self.crop_size = crop_size
+ self.facer = FaceAna(dataroot)
+
+    def filter_face(self, lm, crop_size):
+        # reject faces whose landmark extent is too small for a clean crop
+        a = max(lm[:, 0]) - min(lm[:, 0])
+        b = max(lm[:, 1]) - min(lm[:, 1])
+        # NOTE: the original comparison was truncated in this diff; the threshold below is an assumption
+        if max(a, b) < crop_size / 3.0:
+            return 0
+        return 1
+
+    def process(self, img_bgr):
+        # NOTE: the beginning of this method was truncated in this diff; the landmark
+        # detection call below is an assumed reconstruction based on the surrounding code
+        _, landmarks, _ = self.facer.run(img_bgr)
+        warped_faces = []
+        i = 0
+        for landmark in landmarks:
+            if self.max_face > 0 and i >= self.max_face:
+                continue
+
+            if self.filter_face(landmark, self.crop_size) == 0:
+                print("filtered!")
+                continue
+
+            f5p = utils.get_f5p(landmark, img_bgr)
+            # face alignment
+            warped_face, _ = warp_and_crop_face(
+                img_bgr,
+                f5p,
+                ratio=0.75,
+                reference_pts=get_reference_facial_points(default_square=True),
+                crop_size=(self.crop_size, self.crop_size),
+                return_trans_inv=True)
+
+            warped_faces.append(warped_face)
+            i = i + 1
+
+        return warped_faces
+
+
+
+
+if __name__ == "__main__":
+
+
+    parser = argparse.ArgumentParser(description="extract aligned faces from raw style images")
+ parser.add_argument("--src_dir", type=str, default='', help="Path to src images.")
+ parser.add_argument("--save_dir", type=str, default='', help="Path to save images.")
+ parser.add_argument("--crop_size", type=int, default=256)
+ parser.add_argument("--max_face", type=int, default=1)
+ parser.add_argument("--overwrite", type=int, default=1)
+ args = parser.parse_args()
+ args.save_dir = os.path.dirname(args.src_dir) + '/face_cartoon/raw_style_faces'
+
+ crop_size = args.crop_size
+ max_face = args.max_face
+ overwrite = args.overwrite
+
+ # model_dir = snapshot_download('damo/cv_unet_person-image-cartoon_compound-models', cache_dir='.')
+ # print('model assets saved to %s'%model_dir)
+ model_dir = 'damo/cv_unet_person-image-cartoon_compound-models'
+
+    processer = FaceProcesser(dataroot=model_dir, crop_size=crop_size, max_face=max_face)
+
+ src_dir = args.src_dir
+ save_dir = args.save_dir
+
+ # print('Step: start to extract aligned faces ... ...')
+
+ print('src_dir:%s'% src_dir)
+ print('save_dir:%s'% save_dir)
+
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+
+ paths = utils.all_file(src_dir)
+ print('to process %d images'% len(paths))
+
+ for path in sorted(paths):
+ dirname = path[len(src_dir)+1:].split('/')[0]
+
+ outpath = save_dir + path[len(src_dir):]
+ if not overwrite:
+ if os.path.exists(outpath):
+ continue
+
+ sub_dir = os.path.dirname(outpath)
+ # print(sub_dir)
+ if not os.path.exists(sub_dir):
+ os.makedirs(sub_dir, exist_ok=True)
+
+ imgb = None
+ imgc = None
+ img = cv2.imread(path, -1)
+ if img is None:
+ continue
+
+ if len(img.shape)==2:
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+ # print(img.shape)
+ h,w,c = img.shape
+ if h<256 or w<256:
+ continue
+ imgs = []
+
+ # if need resize, resize here
+ img_h, img_w, _ = img.shape
+ warped_faces = processer.process(img)
+ if warped_faces is None:
+ continue
+ # ### only for anime faces, single, not detect face
+ # warped_face = imga
+
+ i=0
+ for res in warped_faces:
+ # filter small faces
+ h, w, c = res.shape
+ if h < 256 or w < 256:
+ continue
+            # use a per-face output name so outpath is not re-suffixed on later iterations
+            face_outpath = os.path.join(os.path.dirname(outpath), os.path.basename(outpath)[:-4] + '_' + str(i) + '.png')
+
+            cv2.imwrite(face_outpath, res)
+            print('save %s' % face_outpath)
+            i = i + 1
+
+
+
+
+
+
+
diff --git a/notebooks/fastTrain.ipynb b/notebooks/fastTrain.ipynb
new file mode 100644
index 0000000..57905a2
--- /dev/null
+++ b/notebooks/fastTrain.ipynb
@@ -0,0 +1,215 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "collapsed_sections": []
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+        "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/menyifang/DCT-Net/blob/main/notebooks/inference.ipynb)"
+ ],
+ "metadata": {
+ "id": "D2MFmZtpVEp_"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Step 1: Installation"
+ ],
+ "metadata": {
+ "id": "zoNN1PYUOUgU"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "7Tv2ZUgrGO6Z"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install numpy==1.18.5\n",
+ "!pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html\n",
+ "!pip install --upgrade tensorflow-gpu==1.15 # GPU support, use tensorflow for CPU only\n",
+ "!pip install \"modelscope[cv]==1.3.0\" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html\n",
+ "!pip install \"modelscope[multi-modal]==1.3.0\" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Step 2: Data preparation\n",
+ "\n",
+        "A style-pretrained diffusion model can be used to produce raw style data (celebrity name -> image)."
+ ],
+ "metadata": {
+ "id": "kTiXDn7_OTMy"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import cv2\n",
+ "from modelscope.pipelines import pipeline\n",
+ "from modelscope.utils.constant import Tasks\n",
+ "\n",
+ "pipe = pipeline(Tasks.text_to_image_synthesis, model='damo/cv_cartoon_stable_diffusion_clipart', model_revision='v1.0.0')\n",
+ "from diffusers.schedulers import EulerAncestralDiscreteScheduler\n",
+ "pipe.pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.pipeline.scheduler.config)\n",
+ "output = pipe({'text': 'archer style, a portrait painting of Johnny Depp'})\n",
+ "cv2.imwrite('result.png', output['output_imgs'][0])\n",
+ "print('Image saved to result.png')\n",
+ "\n",
+ "print('finished!')"
+ ],
+ "metadata": {
+ "id": "lL2JQBL5Qqjn"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Step 3: Training\n",
+ "\n",
+        "Fast training of the local texture translator with diverse training data:"
+ ],
+ "metadata": {
+ "id": "fShi4lT-ODdE"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import os\n",
+ "import unittest\n",
+ "import cv2\n",
+ "from modelscope.exporters.cv import CartoonTranslationExporter\n",
+ "from modelscope.msdatasets import MsDataset\n",
+ "from modelscope.outputs import OutputKeys\n",
+ "from modelscope.pipelines import pipeline\n",
+ "from modelscope.pipelines.base import Pipeline\n",
+ "from modelscope.trainers.cv import CartoonTranslationTrainer\n",
+ "from modelscope.utils.constant import Tasks\n",
+ "from modelscope.utils.test_utils import test_level\n",
+ "\n",
+ "model_id = 'damo/cv_unet_person-image-cartoon_compound-models'\n",
+ "data_dir = MsDataset.load(\n",
+ " 'dctnet_train_clipart_mini_ms',\n",
+ " namespace='menyifang',\n",
+ " split='train').config_kwargs['split_config']['train']\n",
+ "\n",
+ "# replace data_dir with your own training data\n",
+ "data_photo = os.path.join(data_dir, 'face_photo')\n",
+ "data_cartoon = os.path.join(data_dir, 'face_cartoon')\n",
+ "work_dir = 'exp_localtoon'\n",
+        "# recommend 30000 steps for real training\n",
+ "max_steps = 10\n",
+ "trainer = CartoonTranslationTrainer(\n",
+ " model=model_id,\n",
+ " work_dir=work_dir,\n",
+ " photo=data_photo,\n",
+ " cartoon=data_cartoon,\n",
+ " max_steps=max_steps)\n",
+ "trainer.train()\n"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 1000
+ },
+ "id": "k7Llm0SWMLKr",
+ "outputId": "20b665ce-653e-48a9-c489-d66e93e1907f"
+ },
+ "execution_count": 21,
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "