diff --git a/.gitignore b/.gitignore index 4808a34..a552964 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,4 @@ .glide/ vendor/ glide.lock +.idea/ diff --git a/.travis.yml b/.travis.yml index 8d9f7f2..d682c01 100644 --- a/.travis.yml +++ b/.travis.yml @@ -6,6 +6,4 @@ before_install: install: - glide install script: - - ls - - go build ./voice - - go build ./ocr + - go test ./... diff --git a/README.md b/README.md index e800e2e..b9b5fba 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Baidu-ai-go-sdk ![travis](https://travis-ci.org/chenqinghe/baidu-ai-go-sdk.svg?branch=master) +# Baidu-ai-go-sdk   ![travis](https://travis-ci.org/chenqinghe/baidu-ai-go-sdk.svg?branch=develop) 基于百度REST API封装的go语言sdk,提供简易友好的接口,让开发变得简单。 # Todo list @@ -6,9 +6,10 @@ - [x] 语音合成 - [x] 语音识别 -## 视觉技术 -### 文字识别 +### 视觉技术 + +#### 文字识别 - [x] 通用文字识别 - [x] 通用文字识别(含位置信息版) - [x] 通用文字识别(含生僻字版) @@ -19,33 +20,34 @@ - [x] 行驶证识别 - [ ] 表格文字识别 -### 人脸识别 + +#### 人脸识别 - [ ] 人脸检测 - [ ] 人脸对比 - [ ] 人脸查找 - [ ] 人脸库管理 - [ ] 公安验证 -### 图像审核 +#### 图像审核 - [ ] 图像审核 - [ ] GIF色情识别 - [ ] 图像审核组合服务接口 - [ ] 用户头像审核 -### 图像识别 +#### 图像识别 - [ ] 通用图像分析 - [ ] 细粒度图像识别 - [ ] 定制化图像识别 -### 图像搜索 +#### 图像搜索 - [ ] 相同图检索 - [ ] 相似图检索 - [ ] 商品检索 -## 自然语言 +### 自然语言 -### 语言处理基础技术 +#### 语言处理基础技术 - [ ] 词法分析 - [ ] 依存句法分析 - [ ] 词向量标识 @@ -58,10 +60,10 @@ - [ ] 词性标注 -### 理解与交互技术UNIT +#### 理解与交互技术UNIT - [ ] UNIT对话接口 -### 文本审核 +#### 文本审核 - [ ] 通用类文本反作弊 diff --git a/internal/client.go b/client.go similarity index 91% rename from internal/client.go rename to client.go index d3670cb..3284417 100644 --- a/internal/client.go +++ b/client.go @@ -1,4 +1,4 @@ -package internal +package gosdk import ( "errors" @@ -6,7 +6,14 @@ import ( "github.com/imroc/req" ) -const VOICE_AUTH_URL string = "https://openapi.baidu.com/oauth/2.0/token" +const VOICE_AUTH_URL = "https://openapi.baidu.com/oauth/2.0/token" + +//Authorizer 用于设置access_token +//可以通过RESTFul api的方式从百度方获取 +//有效期为一个月,可以存至数据库中然后从数据库中获取 +type Authorizer interface { + Authorize(*Client) error +} type Client struct { ClientID string @@ -31,17 +38,8 @@ type AuthResponseFailed struct { ErrorDescription string `json:"error_description"` //错误描述信息,帮助理解和解决发生的错误。 } -//Authorizer 用于设置access_token -//可以通过RESTFul api的方式从百度方获取 -//有效期为一个月,可以存至数据库中然后从数据库中获取 -type Authorizer interface { - Authorize(client *Client) error -} - type DefaultAuthorizer struct{} -type RestApiAuthorizer DefaultAuthorizer - func (da DefaultAuthorizer) Authorize(client *Client) error { resp, err := req.Post(VOICE_AUTH_URL, req.Param{ "grant_type": "client_credentials", @@ -67,7 +65,11 @@ func (client *Client) Auth() error { if client.AccessToken != "" { return nil } - return client.Authorizer.Authorize(client) + + if err := client.Authorizer.Authorize(client); err != nil { + return err + } + return nil } func (client *Client) SetAuther(auth Authorizer) { diff --git a/example/16k.pcm b/example/16k.pcm new file mode 100644 index 0000000..e5194ef Binary files /dev/null and b/example/16k.pcm differ diff --git a/example/hello.wav b/example/hello.wav deleted file mode 100644 index 6610f8f..0000000 Binary files a/example/hello.wav and /dev/null differ diff --git a/example/version/ocr/ocr.go b/example/version/ocr/ocr.go new file mode 100644 index 0000000..afe554d --- /dev/null +++ b/example/version/ocr/ocr.go @@ -0,0 +1,29 @@ +package main + +import ( + "fmt" + "github.com/chenqinghe/baidu-ai-go-sdk/version/ocr" + "os" +) + +const ( + // This Api Key and Api Secret is just for example, + // you should get your own first. + APIKEY = "5RijeBzVjQ82uPx8gxGGfeNXlfRt7yH6" + APISECRET = "keiyq3oKrkYsSPUcrf0gtRKneeTxjuqV" +) + +func main() { + + client := ocr.NewOCRClient(APIKEY, APISECRET) + + f, err := os.OpenFile("ocr.jpg", os.O_RDONLY, 0777) + if err != nil { + panic(err) + } + rs, err := client.GeneralRecognizeBasic(f) + if err != nil { + panic(err) + } + fmt.Println(string(rs)) +} diff --git a/ocr/ocr.jpg b/example/version/ocr/ocr.jpg similarity index 100% rename from ocr/ocr.jpg rename to example/version/ocr/ocr.jpg diff --git a/example/voice.go b/example/voice.go index 19f5410..67422d4 100644 --- a/example/voice.go +++ b/example/voice.go @@ -4,8 +4,6 @@ import ( "github.com/chenqinghe/baidu-ai-go-sdk/voice" "log" "os" - "io/ioutil" - "encoding/base64" "fmt" ) @@ -24,7 +22,7 @@ func TextToSpeech() { log.Fatal(err) } - f, err := os.OpenFile("hello.mp3", os.O_CREATE|os.O_WRONLY, 0777) + f, err := os.OpenFile("hello.mp3", os.O_CREATE|os.O_WRONLY, 0644) if err != nil { log.Fatal(err) } @@ -43,30 +41,18 @@ func SpeechToText() { log.Fatal(err) } - f, err := os.OpenFile("hello.wav", os.O_RDONLY, 0666) + f, err := os.OpenFile("16k.pcm", os.O_RDONLY, 0666) if err != nil { log.Fatal(err) } - - fi, err1 := ioutil.ReadAll(f) - if err1 != nil { - log.Fatal(err1) - } - afterBase64Str := base64.StdEncoding.EncodeToString(fi) - fiLen := len(fi) - param := voice.ASRParams{ - Format: "wav", - Rate: 16000, - Channel: 1, - Cuid: "12312312112", - Token: client.AccessToken, - Lan: "zh", - Speech: afterBase64Str, - Len: fiLen, - } - rs, err2 := client.SpeechToText(param) - if err2 != nil { - log.Fatal(err2) + + rs, err := client.SpeechToText( + f, + voice.Format("pcm"), + voice.Channel(1), + ) + if err != nil { + log.Fatal(err) } fmt.Println(rs) } diff --git a/ocr/general.go b/ocr/general.go deleted file mode 100644 index a7905c9..0000000 --- a/ocr/general.go +++ /dev/null @@ -1,117 +0,0 @@ -package ocr - -import ( - "encoding/base64" - - sdk "github.com/chenqinghe/baidu-ai-go-sdk/internal" - "github.com/imroc/req" -) - -const ( - OCR_GENERAL_BASIC_URL string = "https://aip.baidubce.com/rest/2.0/ocr/v1/general_basic" - OCR_GENERAL_WITH_LOCATION_URL = "https://aip.baidubce.com/rest/2.0/ocr/v1/general" - OCR_GENERAL_ENHANCED_URL = "https://aip.baidubce.com/rest/2.0/ocr/v1/general_enhanced" -) - -var defaultGeneralBasicParams = map[string]string{ - "image": "", //图像数据,base64编码,要求base64编码后大小不超过4M,最短边至少15px,最长边最大4096px,支持jpg/png/bmp格式 - "language_type": "CHN_ENG", //识别语言类型,默认为CHN_ENG。可选值包括: - CHN_ENG:中英文混合; - ENG:英文; - POR:葡萄牙语; - FRE:法语; - GER:德语; - ITA:意大利语; - SPA:西班牙语; - RUS:俄语; - JAP:日语 - "detect_direction": "false", //是否检测图像朝向,默认不检测,即:false。朝向是指输入图像是正常方向、逆时针旋转90/180/270度。可选值包括: - true:检测朝向; - false:不检测朝向。 - "detect_language": "false", //是否检测语言,默认不检测。当前支持(中文、英语、日语、韩语) -} - -var defaultGeneralWithLocationParams = map[string]string{ - "image": "", //图像数据,base64编码,要求base64编码后大小不超过4M,最短边至少15px,最长边最大4096px,支持jpg/png/bmp格式 - "recognize_granularity": "big", //是否定位单字符位置,big:不定位单字符位置,默认值;small:定位单字符位置 - "language_type": "CHN_ENG", //识别语言类型,默认为CHN_ENG。可选值包括: - CHN_ENG:中英文混合; - ENG:英文; - POR:葡萄牙语; - FRE:法语; - GER:德语; - ITA:意大利语; - SPA:西班牙语; - RUS:俄语; - JAP:日语 - "detect_direction": "false", //是否检测图像朝向,默认不检测,即:false。朝向是指输入图像是正常方向、逆时针旋转90/180/270度。可选值包括:- true:检测朝向; - false:不检测朝向 - "detect_language": "false", //是否检测语言,默认不检测。当前支持(中文、英语、日语、韩语) - "vertexes_location": "false", //是否返回文字外接多边形顶点位置,不支持单字位置。默认为false - "probability": "false", //是否返回识别结果中每一行的置信度 -} - -var defaultDeneralEnhancedParams = map[string]string{ - "image": "", //图像数据,base64编码,要求base64编码后大小不超过4M,最短边至少15px,最长边最大4096px,支持jpg/png/bmp格式 - "language_type": "CHN_ENG", //识别语言类型,默认为CHN_ENG。可选值包括: - CHN_ENG:中英文混合; - ENG:英文; - POR:葡萄牙语; - FRE:法语; - GER:德语; - ITA:意大利语; - SPA:西班牙语; - RUS:俄语; - JAP:日语 - "detect_direction": "false", //是否检测图像朝向,默认不检测,即:false。朝向是指输入图像是正常方向、逆时针旋转90/180/270度。可选值包括:- true:检测朝向; - false:不检测朝向 - "detect_language": "false", //是否检测语言,默认不检测。当前支持(中文、英语、日语、韩语) - "probability": "false", //是否返回识别结果中每一行的置信度 -} - -type OCRClient struct { - *sdk.Client -} - -func NewOCRClient(apiKey, secretKey string) *OCRClient { - return &OCRClient{ - Client: sdk.NewClient(apiKey, secretKey), - } -} - -//GeneralRecognizeBasic 通用文字识别 -//img 图片二进制数据 -//conf 请求参数 -func (oc *OCRClient) GeneralRecognizeBasic(img []byte, conf map[string]string) ([]byte, error) { - if err := oc.Auth(); err != nil { - return nil, err - } - encodedImgStr := base64.StdEncoding.EncodeToString(img) - conf["image"] = encodedImgStr - - conf = parseParams(defaultGeneralBasicParams, conf) - - var url string = OCR_GENERAL_BASIC_URL + "?access_token=" + oc.AccessToken - - return doRequest(url, conf) -} - -//GeneralRecognizeWithLocation 通用文字识别(含位置信息) -func (oc *OCRClient) GeneralRecognizeWithLocation(img []byte, conf map[string]string) ([]byte, error) { - if err := oc.Auth(); err != nil { - return nil, err - } - encodedImgStr := base64.StdEncoding.EncodeToString(img) - conf["image"] = encodedImgStr - conf = parseParams(defaultGeneralWithLocationParams, conf) - - var url string = OCR_GENERAL_WITH_LOCATION_URL + "?access_token=" + oc.AccessToken - - return doRequest(url, conf) - -} - -//GeneralRecognizeEnhanced 通用文字识别(含生僻字) -func (oc *OCRClient) GeneralRecognizeEnhanced(img []byte, conf map[string]string) ([]byte, error) { - if err := oc.Auth(); err != nil { - return nil, err - } - encodedImgStr := base64.StdEncoding.EncodeToString(img) - conf["image"] = encodedImgStr - - conf = parseParams(defaultDeneralEnhancedParams, conf) - - url := OCR_GENERAL_ENHANCED_URL + "?access_token=" + oc.AccessToken - - return doRequest(url, conf) - -} - -func parseParams(def, need map[string]string) map[string]string { - for key, _ := range def { - if val, ok := need[key]; ok { - def[key] = val - } - } - return def -} - -func doRequest(url string, params map[string]interface{}) (rs []byte, err error) { - - resp, err := req.Post(url, req.Param(params), req.Header{"Content-Type": "application/x-www-form-urlencoded"}) - if err != nil { - return - } - rs, err = resp.ToBytes() - return - -} diff --git a/ocr/particular.go b/ocr/particular.go deleted file mode 100644 index 591ef4b..0000000 --- a/ocr/particular.go +++ /dev/null @@ -1,83 +0,0 @@ -package ocr - -const ( - OCR_WEBIMAGE_URL string = "https://aip.baidubce.com/rest/2.0/ocr/v1/webimage" - OCR_IDCARD_URL = "https://aip.baidubce.com/rest/2.0/ocr/v1/idcard" - OCR_BANKCARD_URL = "https://aip.baidubce.com/rest/2.0/ocr/v1/bankcard" - OCR_DRIVERLICENSE_URL = "https://aip.baidubce.com/rest/2.0/ocr/v1/driving_license" - OCR_VEHICLELICENSE_URL = "https://aip.baidubce.com/rest/2.0/ocr/v1/vehicle_license" - OCR_LICENSEPLATE_URL = "https://aip.baidubce.com/rest/2.0/ocr/v1/license_plate" - OCR_FORM_URL = "https://aip.baidubce.com/rest/2.0/solution/v1/form_ocr/request" -) - -var ( - defaultWebimgParams = defaultDeneralEnhancedParams - defaultIdcardParams = map[string]string{ - "image": "", //图像数据,base64编码,要求base64编码后大小不超过4M,最短边至少15px,最长边最大4096px,支持jpg/png/bmp格式 - "detect_direction": "false", //是否检测图像朝向,默认不检测,即:false。朝向是指输入图像是正常方向、逆时针旋转90/180/270度。可选值包括: - true:检测朝向; - false:不检测朝向。 - "id_card_side": "front", //front:身份证正面;back:身份证背面 - "detect_risk": "false", //是否开启身份证风险类型(身份证复印件、临时身份证、身份证翻拍、修改过的身份证)功能,默认不开启,即:false。可选值:true-开启;false-不开启 - } - defaultBankcardParams = map[string]string{ - "image": "", //图像数据,base64编码,要求base64编码后大小不超过4M,最短边至少15px,最长边最大4096px,支持jpg/png/bmp格式 - } - defaultDriverLicenseParams = map[string]string{ - "image": "", //图像数据,base64编码,要求base64编码后大小不超过4M,最短边至少15px,最长边最大4096px,支持jpg/png/bmp格式 - "detect_direction": "false", //是否检测图像朝向,默认不检测,即:false。朝向是指输入图像是正常方向、逆时针旋转90/180/270度。可选值包括: - true:检测朝向; - false:不检测朝向。 - } - defaultVehicleLicenseParams = map[string]string{ - "image": "", //图像数据,base64编码,要求base64编码后大小不超过4M,最短边至少15px,最长边最大4096px,支持jpg/png/bmp格式 - "detect_direction": "false", //是否检测图像朝向,默认不检测,即:false。朝向是指输入图像是正常方向、逆时针旋转90/180/270度。可选值包括: - true:检测朝向; - false:不检测朝向。 - "accuracy": "", //normal 使用快速服务,1200ms左右时延;缺省或其它值使用高精度服务,1600ms左右时延 - } - defaultLicensePlateParams = defaultBankcardParams - defaultFormParams = defaultBankcardParams -) - -func (oc *OCRClient) WebImageRecognize(img []byte, conf map[string]string) ([]byte, error) { - - return oc.generalOperate(img, OCR_WEBIMAGE_URL, conf, defaultWebimgParams) - -} - -func (oc *OCRClient) IdcardRecognize(img []byte, conf map[string]string) ([]byte, error) { - - return oc.generalOperate(img, OCR_IDCARD_URL, conf, defaultIdcardParams) -} - -func (oc *OCRClient) BankcardRecognize(img []byte, conf map[string]string) ([]byte, error) { - - return oc.generalOperate(img, OCR_BANKCARD_URL, conf, defaultBankcardParams) - -} - -func (oc *OCRClient) DriverLicenseRecognize(img []byte, conf map[string]string) ([]byte, error) { - - return oc.generalOperate(img, OCR_DRIVERLICENSE_URL, conf, defaultDriverLicenseParams) -} - -func (oc *OCRClient) VehicleLicenseRecognize(img []byte, conf map[string]string) ([]byte, error) { - - return oc.generalOperate(img, OCR_VEHICLELICENSE_URL, conf, defaultVehicleLicenseParams) -} - -func (oc *OCRClient) LicensePlateRecognize(img []byte, conf map[string]string) ([]byte, error) { - - return oc.generalOperate(img, OCR_LICENSEPLATE_URL, conf, defaultLicensePlateParams) -} - -func (oc *OCRClient) FromdataRecognize(img []byte, conf map[string]string) ([]byte, error) { - - return oc.generalOperate(img, OCR_FORM_URL, conf, defaultFormParams) -} - -func (oc *OCRClient) generalOperate(img []byte, baseurl string, conf, def map[string]string) ([]byte, error) { - if err := oc.Auth(); err != nil { - return nil, err - } - conf = parseParams(def, conf) - - url := baseurl + "?access_token=" + oc.AccessToken - - return doRequest(url, conf) -} diff --git a/version/ocr/default.go b/version/ocr/default.go new file mode 100644 index 0000000..23a1a1d --- /dev/null +++ b/version/ocr/default.go @@ -0,0 +1,57 @@ +package ocr + +var defaultGeneralBasicParams = map[string]interface{}{ + "image": "", //图像数据,base64编码,要求base64编码后大小不超过4M,最短边至少15px,最长边最大4096px,支持jpg/png/bmp格式 + "url": "", //图片完整URL,URL长度不超过1024字节,URL对应的图片base64编码后大小不超过4M,最短边至少15px,最长边最大4096px,支持jpg/png/bmp格式,当image字段存在时url字段失效,不支持https的图片链接 + "language_type": "CHN_ENG", //识别语言类型,默认为CHN_ENG。可选值包括: - CHN_ENG:中英文混合; - ENG:英文; - POR:葡萄牙语; - FRE:法语; - GER:德语; - ITA:意大利语; - SPA:西班牙语; - RUS:俄语; - JAP:日语 + "detect_direction": "false", //是否检测图像朝向,默认不检测,即:false。朝向是指输入图像是正常方向、逆时针旋转90/180/270度。可选值包括: - true:检测朝向; - false:不检测朝向。 + "detect_language": "false", //是否检测语言,默认不检测。当前支持(中文、英语、日语、韩语) + "probability": "false", //是否返回识别结果中每一行的置信度 +} + +var defaultGeneralWithLocationParams = map[string]interface{}{ + "image": "", //图像数据,base64编码,要求base64编码后大小不超过4M,最短边至少15px,最长边最大4096px,支持jpg/png/bmp格式 + "url": "", //图片完整URL,URL长度不超过1024字节,URL对应的图片base64编码后大小不超过4M,最短边至少15px,最长边最大4096px,支持jpg/png/bmp格式,当image字段存在时url字段失效,不支持https的图片链接 + "recognize_granularity": "big", //是否定位单字符位置,big:不定位单字符位置,默认值;small:定位单字符位置 + "language_type": "CHN_ENG", //识别语言类型,默认为CHN_ENG。可选值包括: - CHN_ENG:中英文混合; - ENG:英文; - POR:葡萄牙语; - FRE:法语; - GER:德语; - ITA:意大利语; - SPA:西班牙语; - RUS:俄语; - JAP:日语 + "detect_direction": "false", //是否检测图像朝向,默认不检测,即:false。朝向是指输入图像是正常方向、逆时针旋转90/180/270度。可选值包括:- true:检测朝向; - false:不检测朝向 + "detect_language": "false", //是否检测语言,默认不检测。当前支持(中文、英语、日语、韩语) + "vertexes_location": "false", //是否返回文字外接多边形顶点位置,不支持单字位置。默认为false + "probability": "false", //是否返回识别结果中每一行的置信度 +} + +var defaultDeneralEnhancedParams = map[string]interface{}{ + "image": "", //图像数据,base64编码,要求base64编码后大小不超过4M,最短边至少15px,最长边最大4096px,支持jpg/png/bmp格式 + "language_type": "CHN_ENG", //识别语言类型,默认为CHN_ENG。可选值包括: - CHN_ENG:中英文混合; - ENG:英文; - POR:葡萄牙语; - FRE:法语; - GER:德语; - ITA:意大利语; - SPA:西班牙语; - RUS:俄语; - JAP:日语 + "detect_direction": "false", //是否检测图像朝向,默认不检测,即:false。朝向是指输入图像是正常方向、逆时针旋转90/180/270度。可选值包括:- true:检测朝向; - false:不检测朝向 + "detect_language": "false", //是否检测语言,默认不检测。当前支持(中文、英语、日语、韩语) + "probability": "false", //是否返回识别结果中每一行的置信度 +} + +var defaultWebimgParams = defaultDeneralEnhancedParams + +var defaultIdcardParams = map[string]interface{}{ + "image": "", //图像数据,base64编码,要求base64编码后大小不超过4M,最短边至少15px,最长边最大4096px,支持jpg/png/bmp格式 + "detect_direction": "false", //是否检测图像朝向,默认不检测,即:false。朝向是指输入图像是正常方向、逆时针旋转90/180/270度。可选值包括: - true:检测朝向; - false:不检测朝向。 + "id_card_side": "front", //front:身份证正面;back:身份证背面 + "detect_risk": "false", //是否开启身份证风险类型(身份证复印件、临时身份证、身份证翻拍、修改过的身份证)功能,默认不开启,即:false。可选值:true-开启;false-不开启 +} + +var defaultBankcardParams = map[string]interface{}{ + "image": "", //图像数据,base64编码,要求base64编码后大小不超过4M,最短边至少15px,最长边最大4096px,支持jpg/png/bmp格式 +} + +var defaultDriverLicenseParams = map[string]interface{}{ + "image": "", //图像数据,base64编码,要求base64编码后大小不超过4M,最短边至少15px,最长边最大4096px,支持jpg/png/bmp格式 + "detect_direction": "false", //是否检测图像朝向,默认不检测,即:false。朝向是指输入图像是正常方向、逆时针旋转90/180/270度。可选值包括: - true:检测朝向; - false:不检测朝向。 +} + +var defaultVehicleLicenseParams = map[string]interface{}{ + "image": "", //图像数据,base64编码,要求base64编码后大小不超过4M,最短边至少15px,最长边最大4096px,支持jpg/png/bmp格式 + "detect_direction": "false", //是否检测图像朝向,默认不检测,即:false。朝向是指输入图像是正常方向、逆时针旋转90/180/270度。可选值包括: - true:检测朝向; - false:不检测朝向。 + "accuracy": "", //normal 使用快速服务,1200ms左右时延;缺省或其它值使用高精度服务,1600ms左右时延 +} + +var defaultLicensePlateParams = defaultBankcardParams + +var defaultFormParams = defaultBankcardParams diff --git a/version/ocr/image.go b/version/ocr/image.go new file mode 100644 index 0000000..a663b18 --- /dev/null +++ b/version/ocr/image.go @@ -0,0 +1,24 @@ +package ocr + +import ( + "image" + "io" +) + +type Size struct { + Height int + Width int +} + +func getImageSize(reader io.Reader) (*Size, error) { + img, _, err := image.Decode(reader) + if err != nil { + return nil, err + } + bounds := img.Bounds() + size := &Size{ + Width: bounds.Dx(), + Height: bounds.Dy(), + } + return size, nil +} diff --git a/version/ocr/ocr.go b/version/ocr/ocr.go new file mode 100644 index 0000000..1f7f476 --- /dev/null +++ b/version/ocr/ocr.go @@ -0,0 +1,126 @@ +package ocr + +import ( + "encoding/base64" + "github.com/chenqinghe/baidu-ai-go-sdk" + "io" + "io/ioutil" +) + +const ( + OCR_GENERAL_BASIC_URL = "https://aip.baidubce.com/rest/2.0/ocr/v1/general_basic" + OCR_GENERAL_WITH_LOCATION_URL = "https://aip.baidubce.com/rest/2.0/ocr/v1/general" + OCR_GENERAL_ENHANCED_URL = "https://aip.baidubce.com/rest/2.0/ocr/v1/general_enhanced" +) + +const ( + OCR_WEBIMAGE_URL = "https://aip.baidubce.com/rest/2.0/ocr/v1/webimage" + OCR_IDCARD_URL = "https://aip.baidubce.com/rest/2.0/ocr/v1/idcard" + OCR_BANKCARD_URL = "https://aip.baidubce.com/rest/2.0/ocr/v1/bankcard" + OCR_DRIVERLICENSE_URL = "https://aip.baidubce.com/rest/2.0/ocr/v1/driving_license" + OCR_VEHICLELICENSE_URL = "https://aip.baidubce.com/rest/2.0/ocr/v1/vehicle_license" + OCR_LICENSEPLATE_URL = "https://aip.baidubce.com/rest/2.0/ocr/v1/license_plate" + OCR_FORM_URL = "https://aip.baidubce.com/rest/2.0/solution/v1/form_ocr/request" +) + +type OCRClient struct { + *gosdk.Client +} + +func NewOCRClient(apiKey, secretKey string) *OCRClient { + return &OCRClient{ + Client: gosdk.NewClient(apiKey, secretKey), + } +} + +//GeneralRecognizeBasic 通用文字识别 +//img 图片二进制数据 +//conf 请求参数 +func (oc *OCRClient) GeneralRecognizeBasic(imageReader io.Reader, params ...RequestParam) ([]byte, error) { + + return oc.ocr(imageReader, OCR_GENERAL_BASIC_URL, defaultGeneralBasicParams, params...) + +} + +//GeneralRecognizeWithLocation 通用文字识别(含位置信息) +func (oc *OCRClient) GeneralRecognizeWithLocation(imageReader io.Reader, params ...RequestParam) ([]byte, error) { + + return oc.ocr(imageReader, OCR_GENERAL_WITH_LOCATION_URL, defaultGeneralWithLocationParams, params...) + +} + +//GeneralRecognizeEnhanced 通用文字识别(含生僻字) +func (oc *OCRClient) GeneralRecognizeEnhanced(imageReader io.Reader, params ...RequestParam) ([]byte, error) { + + return oc.ocr(imageReader, OCR_GENERAL_ENHANCED_URL, defaultDeneralEnhancedParams, params...) + +} + +func (oc *OCRClient) WebImageRecognize(imageReader io.Reader, params ...RequestParam) ([]byte, error) { + + return oc.ocr(imageReader, OCR_WEBIMAGE_URL, defaultWebimgParams, params...) + +} + +func (oc *OCRClient) IdcardRecognize(imageReader io.Reader, params ...RequestParam) ([]byte, error) { + + return oc.ocr(imageReader, OCR_IDCARD_URL, defaultIdcardParams, params...) + +} + +func (oc *OCRClient) BankcardRecognize(imageReader io.Reader, params ...RequestParam) ([]byte, error) { + + return oc.ocr(imageReader, OCR_BANKCARD_URL, defaultBankcardParams, params...) + +} + +func (oc *OCRClient) DriverLicenseRecognize(imageReader io.Reader, params ...RequestParam) ([]byte, error) { + + return oc.ocr(imageReader, OCR_DRIVERLICENSE_URL, defaultDriverLicenseParams, params...) + +} + +func (oc *OCRClient) VehicleLicenseRecognize(imageReader io.Reader, params ...RequestParam) ([]byte, error) { + + return oc.ocr(imageReader, OCR_VEHICLELICENSE_URL, defaultVehicleLicenseParams, params...) + +} + +func (oc *OCRClient) LicensePlateRecognize(imageReader io.Reader, params ...RequestParam) ([]byte, error) { + + return oc.ocr(imageReader, OCR_LICENSEPLATE_URL, defaultLicensePlateParams, params...) + +} + +func (oc *OCRClient) FromdataRecognize(imageReader io.Reader, params ...RequestParam) ([]byte, error) { + + return oc.ocr(imageReader, OCR_FORM_URL, defaultFormParams, params...) + +} + +func (oc *OCRClient) ocr(imageReader io.Reader, url string, def map[string]interface{}, params ...RequestParam) ([]byte, error) { + requestParams, err := parseRequestParam(imageReader, def, params...) + if err != nil { + return nil, err + } + + return oc.doRequest(url, requestParams) +} + +func parseRequestParam(imageReader io.Reader, def map[string]interface{}, params ...RequestParam) (map[string]interface{}, error) { + + imageBytes, err := ioutil.ReadAll(imageReader) + if err != nil { + return nil, err + } + imageBase64Str := base64.StdEncoding.EncodeToString(imageBytes) + + def["image"] = imageBase64Str + + for _, fn := range params { + fn(def) + } + + return def, nil + +} diff --git a/version/ocr/param.go b/version/ocr/param.go new file mode 100644 index 0000000..d26054d --- /dev/null +++ b/version/ocr/param.go @@ -0,0 +1,111 @@ +package ocr + +type RequestParam func(map[string]interface{}) + +//识别语言类型,默认为CHN_ENG。 +func LanguageType(lang string) RequestParam { + options := []string{ + "CHN_ENG", + "ENG", + "POR", + "FRE", + "GER", + "ITA", + "SPA", + "RUS", + "JAP", + "KOR", + } + + illegal := true + for _, v := range options { + if v == lang { + illegal = false + break + } + } + + if illegal { + lang = "CHN_ENG" + } + return func(m map[string]interface{}) { + m["language_type"] = lang + } +} + +//是否检测图像朝向,默认不检测,即:false。朝向是指输入图像是正常方向、逆时针旋转90/180/270度。可选值包括: +//- true:检测朝向; +//- false:不检测朝向。 +func DetectDirection() RequestParam { + return func(m map[string]interface{}) { + m["detect_direction"] = true + } +} + +//是否检测语言,默认不检测。 +//当前支持(中文、英语、日语、韩语) +func DetectLanguage() RequestParam { + return func(m map[string]interface{}) { + m["detect_language"] = true + } +} + +//是否返回识别结果中每一行的置信度 +func WithProbability() RequestParam { + return func(m map[string]interface{}) { + m["probability"] = true + } +} + +//是否定位单字符位置,big:不定位单字符位置,默认值;small:定位单字符位置 +func RecognizeGranularity() RequestParam { + return func(m map[string]interface{}) { + m["recognize_granularity"] = "small" + } +} + +//是否返回文字外接多边形顶点位置,不支持单字位置。默认为false +func WithVertexesLocation() RequestParam { + return func(m map[string]interface{}) { + m["vertexes_location"] = true + } +} + +//front:身份证含照片的一面;back:身份证带国徽的一面 +func IDCardSide(side string) RequestParam { + return func(m map[string]interface{}) { + m["id_card_side"] = side + } +} + +//是否开启身份证风险类型(身份证复印件、临时身份证、身份证翻拍、修改过的身份证)功能,默认不开启,即:false。 +// 可选值:true-开启;false-不开启 +func DetectRisk() RequestParam { + return func(m map[string]interface{}) { + m["detect_risk"] = true + } +} + +//true: 归一化格式输出;false 或无此参数按非归一化格式输出 +func UnifiedValidPeriod() RequestParam { + return func(m map[string]interface{}) { + m["unified_valid_period"] = true + } +} + +//normal 使用快速服务,1200ms左右时延;缺省或其它值使用高精度服务,1600ms左右时延 +func Accuracy(opt string) RequestParam { + if opt != "normal" && opt != "high" { + opt = "normal" + } + return func(m map[string]interface{}) { + m["accuracy"] = "normal" + } +} + +//是否检测多张车牌,默认为false,当置为true的时候可以对一张图片内的多张车牌进行识别 +func MultiDetect() RequestParam { + return func(m map[string]interface{}) { + m["multi_detect"] = true + } +} diff --git a/ocr/readme.md b/version/ocr/readme.md similarity index 100% rename from ocr/readme.md rename to version/ocr/readme.md diff --git a/version/ocr/request.go b/version/ocr/request.go new file mode 100644 index 0000000..44b1980 --- /dev/null +++ b/version/ocr/request.go @@ -0,0 +1,22 @@ +package ocr + +import "github.com/imroc/req" + +func (oc *OCRClient) doRequest(url string, params map[string]interface{}) (rs []byte, err error) { + + if err := oc.Auth(); err != nil { + return nil, err + } + + header := req.Header{ + "Content-Type": "application/x-www-form-urlencoded", + } + + url += "?access_token=" + oc.AccessToken + + resp, err := req.Post(url, req.Param(params), header) + if err != nil { + return + } + return resp.ToBytes() +} diff --git a/voice/asr.go b/voice/asr.go new file mode 100644 index 0000000..547f4f0 --- /dev/null +++ b/voice/asr.go @@ -0,0 +1,131 @@ +package voice + +import ( + "encoding/base64" + "errors" + "io" + "io/ioutil" + + "net" + + "fmt" + "github.com/imroc/req" +) + +const ASR_URL = "http://vop.baidu.com/server_api" + +//语音识别响应信息 +type ASRResponse struct { + CorpusNo string `json:"corpus_no"` + ERRMSG string `json:"err_msg"` + ERRNO int `json:"err_no"` + Result []string `json:"result"` + SN string `json:"sn"` +} + +//语音识别参数 +type ASRParams struct { + Format string `json:"format"` //语音的格式,pcm 或者 wav 或者 amr。不区分大小写 + Rate int `json:"rate"` //采样率,支持 8000 或者 16000 + Channel int `json:"channel"` //声道数,仅支持单声道,请填写固定值 1 + Cuid string `json:"cuid"` //用户唯一标识,用来区分用户,计算UV值。建议填写能区分用户的机器 MAC 地址或 IMEI 码,长度为60字符以内 + Token string `json:"token"` //开放平台获取到的开发者access_token + Language string `json:"lan"` //语种选择,默认中文(zh)。 中文=zh、粤语=ct、英文=en,不区分大小写 + Speech string `json:"speech"` //真实的语音数据 ,需要进行base64 编码。与len参数连一起使用 + Length int `json:"len"` //原始语音长度,单位字节 +} + +type ASRParam func(params *ASRParams) + +func Format(fmt string) ASRParam { + + if fmt != "pcm" && fmt != "wav" && fmt != "amr" { + fmt = "pcm" + } + return func(params *ASRParams) { + params.Format = fmt + } +} + +func Rate(rate int) ASRParam { + if rate != 8000 && rate != 16000 { + rate = 8000 + } + return func(params *ASRParams) { + params.Rate = rate + } +} + +func Channel(c int) ASRParam { + return func(params *ASRParams) { + params.Channel = 1 //固定值1 + } +} + +func Language(lang string) ASRParam { + if lang != "zh" && lang != "ct" && lang != "en" { + lang = "zh" + } + return func(params *ASRParams) { + params.Language = lang + } +} + +////SpeechToText 语音识别,将语音翻译成文字 +func (vc *VoiceClient) SpeechToText(reader io.Reader, params ...ASRParam) ([]string, error) { + content, err := ioutil.ReadAll(reader) + if err != nil { + return nil, err + } + if len(content) > 10*MB { + return nil, errors.New("文件大小不能超过10M") + } + + spch := base64.StdEncoding.EncodeToString(content) + + var cuid string + netitfs, err := net.Interfaces() + if err != nil { + cuid = "anonymous" + } else { + cuid = netitfs[0].HardwareAddr.String() + } + + asrParams := &ASRParams{ + Format: "pcm", + Rate: 8000, + Channel: 1, + Cuid: cuid, + Token: vc.AccessToken, + Language: "zh", + Speech: spch, + Length: len(content), + } + + for _, fn := range params { + fn(asrParams) + } + + header := req.Header{ + "Content-Type": "application/json", + } + + resp, err := req.Post(ASR_URL, header, req.BodyJSON(asrParams)) + if err != nil { + return nil, err + } + + fmt.Println(resp.String()) + + var asrResponse *ASRResponse + if err := resp.ToJSON(asrResponse); err != nil { + return nil, err + } + + if asrResponse.ERRNO != 0 { + return nil, errors.New("调用服务失败:" + asrResponse.ERRMSG) + } + + return asrResponse.Result, nil + +} diff --git a/voice/client.go b/voice/client.go new file mode 100644 index 0000000..f5da646 --- /dev/null +++ b/voice/client.go @@ -0,0 +1,22 @@ +package voice + +import "github.com/chenqinghe/baidu-ai-go-sdk" + +const ( + B = 1 << (10 * iota) + KB + MB + GB + TB + PB +) + +type VoiceClient struct { + *gosdk.Client +} + +func NewVoiceClient(apiKey, apiSecret string) *VoiceClient { + return &VoiceClient{ + Client: gosdk.NewClient(apiKey, apiSecret), + } +} diff --git a/voice/resp.mp3 b/voice/resp.mp3 new file mode 100644 index 0000000..9fd2a28 Binary files /dev/null and b/voice/resp.mp3 differ diff --git a/voice/result.go b/voice/result.go new file mode 100644 index 0000000..ef01e8d --- /dev/null +++ b/voice/result.go @@ -0,0 +1 @@ +package voice diff --git a/voice/tts.go b/voice/tts.go new file mode 100644 index 0000000..c25e1ad --- /dev/null +++ b/voice/tts.go @@ -0,0 +1,150 @@ +// 语音处理 +// 利用百度RESTFul API 进行语音及文字的相互转换 +package voice + +import ( + "errors" + + "io/ioutil" + + "net" + + "encoding/json" + "github.com/imroc/req" +) + +const TTS_URL = "http://tsn.baidu.com/text2audio" + +var ( + ErrTextTooLong = errors.New("The input string is too long") +) + +type TTSParams struct { + Text string `json:"tex"` + Token string `json:"tok"` + Cuid string `json:"cuid"` + ClientType int `json:"ctp"` + Language string `json:"lan"` + Speed int `json:"spd"` + Pitch int `json:"pit"` + Volume int `json:"vol"` + Person int `json:"per"` +} + +type TTSParam func(params *TTSParams) + +func Speed(spd int) TTSParam { + if spd > 9 { + spd = 9 + } + if spd < 0 { + spd = 0 + } + return func(p *TTSParams) { + p.Speed = spd + } +} + +func Pitch(pit int) TTSParam { + if pit > 9 { + pit = 9 + } + if pit < 0 { + pit = 0 + } + return func(p *TTSParams) { + p.Pitch = pit + } +} + +func Volume(vol int) TTSParam { + if vol > 15 { + vol = 15 + } + if vol < 0 { + vol = 0 + } + return func(p *TTSParams) { + p.Volume = vol + } +} + +func Person(per int) TTSParam { + if per != 0 && per != 1 && per != 3 && per != 4 { + per = 0 + } + return func(p *TTSParams) { + p.Person = per + } +} + +//TextToSpeech 语音合成,将文字转换为语音 +func (vc *VoiceClient) TextToSpeech(txt string, params ...TTSParam) ([]byte, error) { + + if len(txt) >= 1024 { + return nil, ErrTextTooLong + } + if err := vc.Auth(); err != nil { + return nil, err + } + + var cuid string + netitfs, err := net.Interfaces() + if err != nil { + cuid = "anonymous" + } else { + cuid = netitfs[0].HardwareAddr.String() + } + + ttsparams := &TTSParams{ + Text: txt, + Token: vc.AccessToken, + Cuid: cuid, + ClientType: 1, + Language: "zh", + Speed: 5, + Pitch: 5, + Volume: 5, + Person: 0, + } + + for _, param := range params { + param(ttsparams) + } + + t, err := json.Marshal(ttsparams) + if err != nil { + return nil, errors.New("serialize failed: " + err.Error()) + } + var p = req.Param{} + if err := json.Unmarshal(t, &p); err != nil { + return nil, err + } + + resp, err := req.Post(TTS_URL, p) + if err != nil { + return nil, err + } + + //通过Content-Type的头部来确定是否服务端合成成功。 + //http://ai.baidu.com/docs#/TTS-API/top + respHeader := resp.Response().Header + contentType, ok := respHeader["Content-Type"] + if !ok { + return nil, errors.New("No Content-Type Set.") + } + if contentType[0] == "audio/mp3" { + respBody, err := ioutil.ReadAll(resp.Response().Body) + if err != nil { + return nil, err + } + return respBody, nil + } else { + respStr, err := resp.ToString() + if err != nil { + return nil, err + } + return nil, errors.New("调用服务失败:" + respStr) + } + +} diff --git a/voice/voice.go b/voice/voice.go deleted file mode 100644 index e11e676..0000000 --- a/voice/voice.go +++ /dev/null @@ -1,170 +0,0 @@ -// 语音处理 -// 利用百度RESTFul API 进行语音及文字的相互转换 -package voice - -import ( - "errors" - - "io/ioutil" - - sdk "github.com/chenqinghe/baidu-ai-go-sdk/internal" - "github.com/imroc/req" - "math/rand" - "net" - "strconv" - "strings" -) - -const ( - TTS_URL string = "http://tsn.baidu.com/text2audio" - ASR_URL string = "http://vop.baidu.com/server_api" -) -const ( - B int = 1 << (10 * iota) - KB - MB -) - -var ( - ErrNoTTSConfig = errors.New("No TTSConfig.please set TTSConfig correctlly first or call method UseDefaultTTSConfig") - ErrTextTooLong = errors.New("The input string is too long") -) - -//VoiceClient 代表一个语音服务应用 -type VoiceClient struct { - *sdk.Client - TTSConfig *TTSConfig -} - -//语音合成参数 -type TTSConfig struct { - SPD int //语速,取值0-9,默认为5中语速 - PIT int //音调,取值0-9,默认为5中语调 - VOL int //音量,取值0-15,默认为5中音量 - PER int //发音人选择, 0为普通女声,1为普通男声,3为情感合成-度逍遥,4为情感合成-度丫丫,默认为普通女声 -} - -var defaultTTSConfig = &TTSConfig{ - SPD: 5, - PIT: 5, - VOL: 5, - PER: 0, -} - -//语音识别响应信息 -type ASRResponse struct { - CorpusNo string `json:"corpus_no"` - ERRMSG string `json:"err_msg"` - ERRNO int `json:"err_no"` - Result []string `json:"result"` - SN string `json:"sn"` -} - -//语音识别参数 -type ASRParams struct { - Format string `json:"format"` //语音的格式,pcm 或者 wav 或者 amr。不区分大小写 - Rate int `json:"rate"` //采样率,支持 8000 或者 16000 - Channel int `json:"channel"` //声道数,仅支持单声道,请填写固定值 1 - Cuid string `json:"cuid"` //用户唯一标识,用来区分用户,计算UV值。建议填写能区分用户的机器 MAC 地址或 IMEI 码,长度为60字符以内 - Token string `json:"token"` //开放平台获取到的开发者access_token - Lan string `json:"lan"` //语种选择,默认中文(zh)。 中文=zh、粤语=ct、英文=en,不区分大小写 - Speech string `json:"speech"` //真实的语音数据 ,需要进行base64 编码。与len参数连一起使用 - Len int `json:"len"` //原始语音长度,单位字节 -} - -//TextToSpeech 语音合成,将文字转换为语音 -func (vc *VoiceClient) TextToSpeech(txt string) ([]byte, error) { - - if len(txt) >= 1024 { - return []byte{}, ErrTextTooLong - } - if err := vc.Auth(); err != nil { - return []byte{}, err - } - if vc.TTSConfig == nil { - return []byte{}, ErrNoTTSConfig - } - itfcs, err := net.Interfaces() - if err != nil { - return []byte{}, err - } - mac := itfcs[0].HardwareAddr.String() - if mac == "" { - mac = randomStr(10) - } - params := req.Param{ - "tex": txt, //必填 合成的文本,使用UTF-8编码,请注意文本长度必须小于1024字节 - "lan": "zh", //必填 语言选择,目前只有中英文混合模式,填写固定值zh - "tok": vc.AccessToken, //必填 开放平台获取到的开发者access_token(见上面的“鉴权认证机制”段落) - "ctp": "1", //必填 客户端类型选择,web端填写固定值1 - "cuid": mac, //必填 用户唯一标识,用来区分用户,计算UV值。建议填写能区分用户的机器 MAC 地址或 IMEI 码,长度为60字符以内 - "spd": strconv.Itoa(vc.TTSConfig.SPD), //选填 语速,取值0-9,默认为5中语速 - "pit": strconv.Itoa(vc.TTSConfig.PIT), //选填 音调,取值0-9,默认为5中语调 - "vol": strconv.Itoa(vc.TTSConfig.VOL), //选填 音量,取值0-15,默认为5中音量 - "per": strconv.Itoa(vc.TTSConfig.PER), //选填 发音人选择, 0为普通女声,1为普通男声,3为情感合成-度逍遥,4为情感合成-度丫丫,默认为普通女声 - } - resp, err := req.Post(TTS_URL, params) - if err != nil { - return []byte{}, err - } - respHeader := resp.Response().Header - contentType, ok := respHeader["Content-Type"] - if !ok { - return []byte{}, errors.New("No Content-Type Set.") - } - if contentType[0] == "audio/mp3" { - respBody, err := ioutil.ReadAll(resp.Response().Body) - if err != nil { - return []byte{}, err - } - return respBody, nil - } else { - respStr, err := resp.ToString() - if err != nil { - return []byte{}, err - } - return []byte{}, errors.New("调用服务失败:" + respStr) - } -} - -//SpeechToText 语音识别,将语音翻译成文字 -func (vc *VoiceClient) SpeechToText(ap ASRParams) ([]string, error) { - if ap.Len > 8*10*MB { - return []string{}, errors.New("文件大小不能超过10M") - } - if err := vc.Auth(); err != nil { - return []string{}, err - } - ap.Token = vc.AccessToken - resp, err := req.Post(ASR_URL, req.Header{ - "Content-Type": "application/json", - }, req.BodyJSON(ap)) - if err != nil { - return []string{}, err - } - var rs ASRResponse - if err := resp.ToJSON(&rs); err != nil { - return []string{}, err - } - if !strings.Contains(rs.ERRMSG, "success") || rs.ERRNO != 0 { - return []string{}, errors.New("调用服务失败:" + rs.ERRMSG) - } - return rs.Result, nil -} - -func NewVoiceClient(ApiKey, secretKey string) *VoiceClient { - return &VoiceClient{ - Client: sdk.NewClient(ApiKey, secretKey), - TTSConfig: defaultTTSConfig, - } -} - -func randomStr(length int) string { - var baseStr = "qwertyuiopasdfghjklzxcvbnmQWERTYUIOPLKJHGFDSAZXCVBNM0123456789" - var bt []byte = make([]byte, length) - for i := 0; i < length; i++ { - k := rand.Intn(62) - bt[i] = baseStr[k] - } - return string(bt) -}