From 800a2bba3de0f0cf54ee53f89c9aa910f1c7c4c6 Mon Sep 17 00:00:00 2001 From: Hyunmin-H Date: Thu, 17 Aug 2023 18:37:33 +0000 Subject: [PATCH] =?UTF-8?q?[Refactor]=20upper=20&=20lower=EA=B9=8C?= =?UTF-8?q?=EC=A7=80=20=EB=8F=99=EC=9E=91=20#33?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - main.py의 inference 함수 다시 구현(inference_preprocess 함수) related to : #31 --- backend/app/frontend.py | 30 ++------- backend/app/main.py | 143 ++++++++-------------------------------- 2 files changed, 31 insertions(+), 142 deletions(-) diff --git a/backend/app/frontend.py b/backend/app/frontend.py index e485c28..cecf3e8 100644 --- a/backend/app/frontend.py +++ b/backend/app/frontend.py @@ -56,10 +56,6 @@ def check_modelLoading(): pass return is_modelLoading -def read_image_as_bytes(image_path): - with open(image_path, "rb") as file: - image_data = file.read() - return image_data ## 이미지 리스트에 저장 def append_imgList(uploaded_garment, category): @@ -89,7 +85,6 @@ def show_garments_and_checkboxes(category): for i, filename in enumerate(filenames): im_dir = os.path.join(category_dir, filename) # garment_img = Image.open(im_dir) - # garment_byte = read_image_as_bytes(im_dir) garment_img = gcs.read_image_from_gcs(im_dir) # st.image(garment_img, caption=filename[:-4], width=100) @@ -104,7 +99,6 @@ def show_garments_and_checkboxes(category): filenames_ = [None] filenames_.extend([f[:-4] for f in filenames]) selected_garment = st.selectbox('입을 옷을 선택해주세요.', filenames_, index=0, key=category) - print('selected_garment', selected_garment) im_dir = os.path.join(category_dir, f'{selected_garment}.jpg') garment_byte = gcs.read_image_from_gcs(im_dir) @@ -205,7 +199,6 @@ def main(): selected_byte, selected_upper = show_garments_and_checkboxes(category) if selected_upper : is_selected_upper = True - # files[2] = ('files', f'{selected_upper}.jpg') files[2] = ('files', selected_byte) print('selected_upper', selected_upper) @@ -222,25 +215,21 @@ def main(): selected_byte, selected_lower = show_garments_and_checkboxes(category) if selected_lower : is_selected_lower = True - files[3] = ('files', f'{selected_lower}.jpg') + files[3] = ('files', selected_byte) st.write(' ') st.write(' ') st.markdown("

드레스👗

", unsafe_allow_html=True) category = 'dresses' - uploaded_garment = st.file_uploader("추가할 드레스를 넣어주세요.", type=["jpg", "jpeg", "png"]) if uploaded_garment : append_imgList(uploaded_garment, category) - filenames, selected_dress = show_garments_and_checkboxes(category) + selected_byte, selected_dress = show_garments_and_checkboxes(category) if selected_dress : is_selected_dress = True - files[2] = ('files', f'{selected_dress}.jpg') - print('is_selected_lower', is_selected_lower) - print('is_selected_dress', is_selected_dress) - + files[2] = ('files', selected_byte) with col2: st.markdown("

드레스룸🚪

", unsafe_allow_html=True) @@ -259,11 +248,6 @@ def main(): human_slot.empty() human_slot.image(target_img) - # else : - - # example_img = Image.open('/opt/ml/level3_cv_finalproject-cv-12/backend/app/utils/example.jpg') - # human_slot.image(example_img, width=300, use_column_width=True, caption='Example of target image') - print('start_button', start_button) if start_button and uploaded_target: if is_selected_upper and is_selected_lower : @@ -305,13 +289,7 @@ def main(): empty_slot.empty() empty_slot.markdown("

Here it is !

", unsafe_allow_html=True) - output_ladi_buffer_dir = '/opt/ml/user_db/ladi/buffer' - final_result_dir = output_ladi_buffer_dir - if category =='upper_lower': - final_img = Image.open(os.path.join(final_result_dir, 'lower_body.png')) - else : - # final_img = Image.open(os.path.join(final_result_dir, f'{category}.png')) - final_img = response.content + final_img = response.content st.write(' ') st.write(' ') diff --git a/backend/app/main.py b/backend/app/main.py index 8f7cbf9..0d17eab 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -9,21 +9,21 @@ # scp setting import sys, os sys.path.append('/opt/ml/level3_cv_finalproject-cv-12/model/Self_Correction_Human_Parsing/') -from simple_extractor import main_schp, main_schp_fromImageByte +from simple_extractor import main_schp # openpose sys.path.append('/opt/ml/level3_cv_finalproject-cv-12/model/pytorch_openpose/') -from extract_keypoint import main_openpose, main_openpose_fromImageByte +from extract_keypoint import main_openpose # ladi sys.path.append('/opt/ml/level3_cv_finalproject-cv-12/model/ladi_vton') sys.path.append('/opt/ml/level3_cv_finalproject-cv-12/model/ladi_vton/src') sys.path.append('/opt/ml/level3_cv_finalproject-cv-12/model/ladi_vton/src/utils') -from get_clothing_mask import main_mask, main_mask_fromImageByte +from get_clothing_mask import main_mask -from inference import main_ladi, main_ladi_fromImageByte -from face_cut_and_paste import main_cut_and_paste +from inference import main_ladi +from face_cut_and_paste import cut_and_paste import torch from accelerate import Accelerator @@ -68,21 +68,6 @@ async def add_garment_to_db(files: List[UploadFile] = File(...)): gcs.upload_blob(garment_bytes, os.path.join(user_name, 'input/garment', category, f'{garment_name}')) # garment_image.save(os.path.join(db_dir, 'input/garment', category, f'{garment_name}')) -def read_image_as_bytes(image_path): - with open(image_path, "rb") as file: - image_data = file.read() - return image_data - -@app.get("/get_db/{category}") -async def get_DB(category: str) : - category_dir = os.path.join(db_dir, 'input/garment', category) - garment_db_bytes = {} - for filename in os.listdir(category_dir): - garment_id = filename[:-4] - garment_byte = read_image_as_bytes(os.path.join(category_dir, filename)) - garment_db_bytes[garment_id] = garment_byte - return garment_db_bytes - def load_ladiModels(): pretrained_model_name_or_path = "stabilityai/stable-diffusion-2-inpainting" @@ -133,61 +118,26 @@ async def get_boolean(): global is_modelLoading return {"is_modelLoading": is_modelLoading} -def inference_allModels(target_bytes, garment_bytes, category, db_dir): - - input_dir = os.path.join(db_dir, 'input') +def inference_preprocess(target_bytes, garment_bytes, garment_lower_bytes=None): # schp - (1024, 784), (512, 384) - target_buffer_dir = os.path.join(input_dir, 'buffer/target') - # main_schp(target_buffer_dir) - schp_img = main_schp_fromImageByte(target_bytes) - schp_img.save('./schp.png') - + schp_img = main_schp(target_bytes) + # openpose - output_openpose_buffer_dir = os.path.join(db_dir, 'openpose/buffer') - os.makedirs(output_openpose_buffer_dir, exist_ok=True) - # main_openpose(target_buffer_dir, output_openpose_buffer_dir) - keypoint_dict = main_openpose_fromImageByte(target_bytes) - gcs.upload_dict_as_json_to_gcs(keypoint_dict, os.path.join(db_dir, 'openpose/buffer/target.json')) - - # /opt/ml/user_db/mask/buffer - # mask - garment_dir = os.path.join(input_dir, 'buffer/garment') - output_mask_dir = os.path.join(db_dir, 'mask/buffer') - os.makedirs(output_mask_dir, exist_ok=True) - # main_mask(category, garment_dir, output_mask_dir) + keypoint_dict = main_openpose(target_bytes) ## garment_mask 형식 - Image - garment_mask = main_mask_fromImageByte(garment_bytes) - - garment_mask.save('./garment_mask.jpg') - - # ladi-vton - output_ladi_buffer_dir = os.path.join(db_dir, 'ladi/buffer') - os.makedirs(output_ladi_buffer_dir, exist_ok=True) - - # main_ladi(category, db_dir, output_ladi_buffer_dir, ladi_models) - finalResult_img = main_ladi_fromImageByte(category, target_bytes, schp_img, keypoint_dict, garment_bytes, garment_mask, ladi_models) - finalResult_img = main_cut_and_paste(category, target_bytes, finalResult_img, schp_img) - return finalResult_img - -def inference_ladi(category, db_dir, target_name='target.jpg'): - input_dir = os.path.join(db_dir, 'input') - garment_dir = os.path.join(input_dir, 'buffer/garment') - output_mask_dir = os.path.join(db_dir, 'mask/buffer') - main_mask(category, garment_dir, output_mask_dir) - - # ladi-vton - output_ladi_buffer_dir = os.path.join(db_dir, 'ladi/buffer') - os.makedirs(output_ladi_buffer_dir, exist_ok=True) - - main_ladi(category, db_dir, output_ladi_buffer_dir, ladi_models, target_name) - main_cut_and_paste(category, db_dir, target_name) + garment_mask = main_mask(garment_bytes) + if garment_lower_bytes is None : + return schp_img, keypoint_dict, garment_mask + else : + garment_lower_mask = main_mask(garment_lower_bytes) + + return schp_img, keypoint_dict, garment_mask, garment_lower_mask # post!! @app.post("/order", description="주문을 요청합니다") async def make_order(files: List[UploadFile] = File(...)): - # input_dir = '/opt/ml/user_db/input/' input_dir = f'{user_name}/input' # category : files[0], target:files[1], garment:files[2] @@ -198,67 +148,28 @@ async def make_order(files: List[UploadFile] = File(...)): ## category가 upper & lower일 경우 target_bytes = await files[1].read() - target_image = Image.open(io.BytesIO(target_bytes)) - target_image = target_image.convert("RGB") - - os.makedirs(f'{input_dir}/buffer', exist_ok=True) - # target_image.save(f'{input_dir}/buffer/target/target.jpg') - gcs.upload_blob(target_bytes, f'{input_dir}/buffer/target/target.jpg') if category == 'upper_lower': - # garment_upper_bytes = await files[2].read() - # garment_lower_bytes = await files[3].read() - - # garment_upper_image = Image.open(io.BytesIO(garment_upper_bytes)) - # garment_upper_image = garment_upper_image.convert("RGB") - # garment_lower_image = Image.open(io.BytesIO(garment_lower_bytes)) - # garment_lower_image = garment_lower_image.convert("RGB") - - # # garment_upper_image.save(f'{input_dir}/upper_body.jpg') - # garment_upper_image.save(f'{input_dir}/buffer/garment/upper_body.jpg') - # # garment_lower_image.save(f'{input_dir}/lower_body.jpg') - # garment_lower_image.save(f'{input_dir}/buffer/garment/lower_body.jpg') - - - ## string으로 전송됐을 때(filename) - string_upper_bytes = await files[2].read() - string_lower_bytes = await files[3].read() - string_io_upper = io.BytesIO(string_upper_bytes) - string_io_lower = io.BytesIO(string_lower_bytes) - filename_upper = string_io_upper.read().decode('utf-8') - filename_lower = string_io_lower.read().decode('utf-8') - - garment_image_upper = Image.open(os.path.join(db_dir, 'input/garment', 'upper_body', filename_upper)) - garment_image_lower = Image.open(os.path.join(db_dir, 'input/garment', 'lower_body', filename_lower)) - garment_image_upper.save(f'{input_dir}/buffer/garment/upper_body.jpg') - garment_image_lower.save(f'{input_dir}/buffer/garment/lower_body.jpg') - + garment_upper_bytes = await files[2].read() + garment_lower_bytes = await files[3].read() - finalResult_img = inference_allModels('upper_body', db_dir) - shutil.copy(os.path.join(db_dir, 'ladi/buffer', 'upper_body.png'), f'{input_dir}/buffer/target/upper_body.jpg') - inference_ladi('lower_body', db_dir, target_name='upper_body.jpg') + schp_img, keypoint_dict, garment_upper_mask, garment_lower_mask = inference_preprocess(target_bytes, garment_upper_bytes, garment_lower_bytes) + ladi_img = main_ladi('upper_body', target_bytes, schp_img, keypoint_dict, garment_upper_bytes, garment_upper_mask, ladi_models) + ladi_bytes = PIL2Byte(ladi_img) + ladi_img = main_ladi('lower_body', ladi_bytes, schp_img, keypoint_dict, garment_lower_bytes, garment_lower_mask, ladi_models) + finalResult_img = cut_and_paste(target_bytes, ladi_img, schp_img) else : ## file로 전송됐을 때 - garment_bytes = await files[2].read() - garment_image = Image.open(io.BytesIO(garment_bytes)) - garment_image = garment_image.convert("RGB") - - ## string으로 전송됐을 때(filename) - # byte_string = await files[2].read() - # string_io = io.BytesIO(byte_string) - # filename = string_io.read().decode('utf-8') - - # garment_image = Image.open(os.path.join(db_dir, 'input/garment', category, filename)) - # garment_image.save(f'{input_dir}/buffer/garment/{category}.jpg') - gcs.upload_blob(garment_bytes, f'{input_dir}/buffer/garment/{category}.jpg') - - finalResult_img = inference_allModels(target_bytes, garment_bytes, category, user_name) + schp_img, keypoint_dict, garment_mask = inference_preprocess(target_bytes, garment_bytes) + ladi_img = main_ladi(category, target_bytes, schp_img, keypoint_dict, garment_bytes, garment_mask, ladi_models) + finalResult_img = cut_and_paste(target_bytes, ladi_img, schp_img) finalResult_bytes = PIL2Byte(finalResult_img) gcs.upload_blob(finalResult_bytes, f'{input_dir}/ladi/buffer/final.jpg') + return StreamingResponse(io.BytesIO(finalResult_bytes), media_type="image/jpg") \ No newline at end of file