Skip to content

Commit

Permalink
Merge pull request #36 from kookmin-sw/feat/flask_data_extract
Browse files Browse the repository at this point in the history
Feat/flask data extract
  • Loading branch information
lkl4502 authored May 3, 2024
2 parents 720fd89 + 4bd91fb commit 76fcb58
Show file tree
Hide file tree
Showing 9 changed files with 147 additions and 10 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ personal_color_dataset/valid/img*.jpg
gan/pretrained_models
gan/output
predict_image/*

21 changes: 19 additions & 2 deletions Color_extract/color.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))
from Skin_detect.skin_detect_v2 import *
from image_processing.gamma_correction import *
from PC_model.utils import draw_3d_rgb

def extract_points(mask, img):
points = np.argwhere(mask)
Expand Down Expand Up @@ -256,7 +257,9 @@ def extract_high_rank(rgb_codes, color_area, percent):
# # file 이름 넣기
# data['filename'][i] = name

def total_data_extract(path):
def total_data_extract(path, save_image):
folder_path = os.path.dirname(path)

data = {'Red' : 0, 'Green' : 0, 'Blue' : 0,
'Hue' : 0, 'Saturation' : 0, 'Value' : 0,
'Y' : 0, 'Cr' : 0, 'Cb' : 0,
Expand All @@ -278,6 +281,10 @@ def total_data_extract(path):
# 이진 마스크로 변환
binary_mask = (face_nose_mask >= 0.5).astype(int)

if save_image:
if not os.path.exists(os.path.join(folder_path, "binary_mask.jpg")):
cv2.imwrite(os.path.join(folder_path, "binary_mask.jpg"), binary_mask * 255)

# image load
image = cv2.imread(path)
image = gamma_correction(image, 0.8)
Expand All @@ -286,6 +293,12 @@ def total_data_extract(path):
ycrcb_image = cv2.cvtColor(image, cv2.COLOR_BGR2YCrCb)
lab_image = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)

if save_image:
if not os.path.exists(os.path.join(folder_path, "face_nose_img.jpg")):
face_nose_img = np.zeros_like(image)
face_nose_img[binary_mask == 1] = image[binary_mask == 1]
cv2.imwrite(os.path.join(folder_path, "face_nose_img.jpg"), face_nose_img)

#RGB
rgb = extract_points(binary_mask, rgb_image)
rgb_average = rgb.mean(axis=0).round()
Expand Down Expand Up @@ -348,4 +361,8 @@ def total_data_extract(path):
data['New Green'] = new_rgb_average[1]
data['New Blue'] = new_rgb_average[2]

return data
if save_image:
draw_3d_rgb(new_rgb_codes, new_rgb_codes/255.0, folder_path)

return data

18 changes: 18 additions & 0 deletions PC_model/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score
from sklearn.metrics import confusion_matrix, f1_score
import seaborn as sns
import matplotlib
matplotlib.use("agg")
import matplotlib.pyplot as plt
import os


def get_evaluation(y_test, y_pred):
Expand All @@ -28,3 +31,18 @@ def feature_plot(data, label_name):

def heatmap_plot(data, number = True):
sns.heatmap(data, annot=number, fmt=".2f")

# 3d plot을 이용하여 데이터 분포 확인
def draw_3d_rgb(rgb_codes, colors, path = None):
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(rgb_codes[:, 0], rgb_codes[:, 1], rgb_codes[:, 2], c=colors, marker='o')
ax.set_xlabel('Red')
ax.set_ylabel('Green')
ax.set_zlabel('Blue')
if path is None:
plt.show()
else:
if not os.path.exists(path + "/rgb_3d_plot.jpg"):
plt.savefig(path + "/rgb_3d_plot.jpg", format='jpeg')

1 change: 1 addition & 0 deletions backend/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ dependencies {
implementation 'org.springframework.boot:spring-boot-starter-thymeleaf'
implementation 'org.springframework.boot:spring-boot-starter-web'
implementation 'org.springframework.boot:spring-boot-starter-data-jpa'
implementation 'com.google.code.gson:gson:2.8.6'
compileOnly 'org.projectlombok:lombok'
compileOnly'org.springframework.boot:spring-boot-devtools'
developmentOnly 'org.springframework.boot:spring-boot-devtools'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;

@SpringBootApplication
@SpringBootApplication(exclude = DataSourceAutoConfiguration.class)
public class BackendApplication {
public static void main(String[] args) {
SpringApplication.run(BackendApplication.class, args);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package org.capstone2024.onlyu.controller;

import com.google.gson.JsonObject;
import lombok.extern.slf4j.Slf4j;
import org.capstone2024.onlyu.service.FlaskService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.*;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.multipart.MultipartFile;

@RestController
@Slf4j
public class FlaskController {
private final FlaskService flaskService;
@Autowired
public FlaskController(FlaskService flaskService){
this.flaskService = flaskService;
}

// 퍼스널 컬러 예측과 피부형 검출
@RequestMapping(value = "/start", method = RequestMethod.POST)
public String start(@RequestParam("email") String email, @RequestParam("gender") String gender,
@RequestParam("image") MultipartFile multipartFile){
String predict_color_res = flaskService.predict_color_flask(multipartFile);
String predict_shape_res = flaskService.predict_shape_flask();

JsonObject obj = new JsonObject();
obj.addProperty("predictColorRes", predict_color_res);
obj.addProperty("predictShapeRes", predict_shape_res);
return obj.toString();
}

// @RequestMapping(value = "/gan_image", method = RequestMethod.GET)
// public String gan_image(){
// Boolean res = flaskService.gan_image_flask();
// JsonObject obj = new JsonObject();
// obj.addProperty("success", res);
// return obj.toString();
// }

}
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,20 @@
import org.springframework.stereotype.Controller;
import org.springframework.ui.Model;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.ResponseBody;

@Controller
public class TestController {
@GetMapping("/")
@PostMapping("/")
@ResponseBody
public String test(Model model){
// 이미지, 성별, 생성형 이미지 유무, 이메일
// flask api인 post 요청으로 퍼스널 컬러 검사 결과 요청
// get요청으로 predict_shape()해서 얼굴형 받기
return "Only You~~";
}

// 하나 더 만들어서 생성형 이미지 유무에 맞게 요청 보내기
// 요거는 front에서 구분해서 보내주면 됨. 대신 -> 로딩 화면 X
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package org.capstone2024.onlyu.service;

import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.multipart.MultipartFile;


@Service
@RequiredArgsConstructor
@Slf4j
public class FlaskService {

@Transactional
public String predict_color_flask(MultipartFile image){
RestTemplate restTemplate = new RestTemplate();
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.MULTIPART_FORM_DATA);
String url = "http://127.0.0.1:5050/predict_color";

MultiValueMap<String, Object> body = new LinkedMultiValueMap<>();
body.add("image", image.getResource());
HttpEntity<MultiValueMap<String, Object>> requestEntity = new HttpEntity<>(body, headers);
return restTemplate.postForObject(url, requestEntity, String.class);
}

@Transactional
public String predict_shape_flask(){
RestTemplate restTemplate = new RestTemplate();
String url = "http://127.0.0.1:5050/predict_shape";

return restTemplate.getForObject(url, String.class);
}
}
18 changes: 12 additions & 6 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
from PC_model.pc_model import PersonalColorModel
from image_processing.gamma_correction import gamma_correction

from sklearn.preprocessing import StandardScaler

import joblib
import os
import pandas as pd
Expand Down Expand Up @@ -33,15 +31,22 @@ def predict_color():
global features, pc_model, ss, current_image_path
# 이미지 저장
f = request.files['image']
f_path = os.path.join(image_path, f.filename)
filename = f.filename

type = f.filename[f.filename.rfind("."):]

folder_path = os.path.join(image_path, f.filename.replace(".", "_"))
f_path = os.path.join(folder_path, "origin_img" + type)

if not os.path.exists(folder_path):
os.makedirs(folder_path)

if not os.path.exists(f_path):
f.save(f_path)

current_image_path = f_path[:]

# 데이터 추출
data = total_data_extract(f_path)
data = total_data_extract(f_path, True)

# DataFrame으로 변환
df = pd.DataFrame(data, index = [0])
Expand All @@ -54,6 +59,7 @@ def predict_color():

# 예츨 결과
raw_res = pc_model.test(preprocssing_data)


predict_res = [""] * len(raw_res)
for idx, predict in enumerate(raw_res):
Expand All @@ -75,7 +81,7 @@ def test():
from shape_detect.controller import get_shape
@app.route('/predict_shape', methods =['GET'])
def predict_shape():
global image_path
global current_image_path
result = get_shape(current_image_path)

if result == -1:
Expand Down

0 comments on commit 76fcb58

Please sign in to comment.