第五期【百度大脑新品体验】风格转换与拉伸恢复
busyboxs 发布于2019-08 浏览:1480 回复:0
0
收藏
最后编辑于2022-04

本文将借助百度AI开放平台的图像处理的 API,使用 python 和 Flask 实现一个 web 界面。本文将对图像风格转换和拉伸图像内容进行说明,接下来将对接口以及实现过程进行介绍。

源码地址: https://github.com/busyboxs/image-process-flask

接口介绍

图像风格转换

  • 请求 url : https://aip.baidubce.com/rest/2.0/image-process/v1/style_trans
  • 请求参数
  • 返回参数:
  • 返回示例:
    {
        "log_id": "6876747463538438254",
        "image": "处理后图片的Base64编码"
    }

拉伸图像恢复

  • 请求 url : https://aip.baidubce.com/rest/2.0/image-process/v1/stretch_restore
  • 请求参数:
  • 返回参数:
  • 返回示例:
    {
        "log_id": "6876747463538438254",
        "image": "处理后图片的Base64编码"
    }

代码实现

整个代码的文件目录如下图所示:

  • images:该文件夹用于存放测试图像,保存于官方实例;
  • static/css:该文件存放web界面的样式文件;
  • static/fonts: 该文件夹存放的字体用于美化界面;
  • static/images: 该文件夹主要用于存放通过界面上传的图像;
  • static/js: 该文件存放一些模板的界面交互文件;
  • templates: 该文件主要用于存放 html 文件,每个html代表一个页面;
  • base.py: 调用各种api的一个基本类;
  • image_process.py: 调用图像处理接口类;
  • run.py: flask主要代码运行。

由于我也是一个前端小白,所以对于前端的内容不做详细说明。接下来对其中的部分主要代码进行说明。

调用 api 的基类

该基类的代码如下

import requests
import time
import json
import base64
from urllib.parse import urlencode
from urllib.parse import quote
from urllib.parse import urlparse


class Base(object):

    __access_token_url = 'https://aip.baidubce.com/oauth/2.0/token'

    def __init__(self, client_id, client_secret):
        self._client_id = client_id
        self._client_secret = client_secret
        self._auth_obj = {}
        self.__connect_timeout = 60.0
        self.__socket_timeout = 60.0

    def _auth(self, refresh=False):
        if not refresh:
            tm = self._auth_obj.get("time", 0) + int(self._auth_obj.get('expires_in', 0)) - 30
            if tm > int(time.time()):
                return self._auth_obj

        obj = requests.get(self.__access_token_url,
                           verify=False,
                           params={
                               'grant_type': 'client_credentials',
                               'client_id': self._client_id,
                               'client_secret': self._client_secret},
                           timeout=(
                               self.__connect_timeout,
                               self.__socket_timeout)).json()

        obj['time'] = int(time.time())
        self._auth_obj = obj
        return obj

    def _request(self, url, data, headers=None):
        try:
            auth_obj = self._auth()
            params = self._get_params(auth_obj)
            headers = {'Content-Type': 'application/x-www-form-urlencoded'}
            response = requests.post(url=url,
                                     data=data,
                                     params=params,
                                     verify=False,
                                     headers=headers,
                                     timeout=(
                                         self.__connect_timeout,
                                         self.__socket_timeout,
                                     ))
            obj = json.loads(response.content.decode())

            if obj.get('error_code', '') == 110:
                auth_obj = self._auth(True)
                params = self._get_params(auth_obj)
                response = requests.post(url=url,
                                         data=data,
                                         params=params,
                                         verify=False,
                                         headers=headers,
                                         timeout=(
                                             self.__connect_timeout,
                                             self.__socket_timeout,
                                         ))
                obj = json.loads(response.content.decode())
        except (requests.exceptions.ReadTimeout, requests.exceptions.ConnectTimeout) as e:
            return {
                'error_code': 'SDK108',
                'error_msg': 'connection or read data timeout',
            }

        return obj

    def _get_params(self, auth_obj):
        params = {'access_token': auth_obj['access_token']}
        return params

首先是初始化

def __init__(self, client_id, client_secret):
        self._client_id = client_id
        self._client_secret = client_secret
        self._auth_obj = {}
        self.__connect_timeout = 60.0
        self.__socket_timeout = 60.0

这里传入了参数 client_id 和 client_secret,这两个参数的具体内容为我们创建的应用的API Key 和 Secret Key,可以从控制台获取。

_auth_obj 用于存放获取 access token 的内容,另外还添加了获取 access token 时的时间,其作用主要是用于判断access token是否过期。

然后是鉴权函数

    def _auth(self, refresh=False):
        if not refresh:
            tm = self._auth_obj.get("time", 0) + int(self._auth_obj.get('expires_in', 0)) - 30
            if tm > int(time.time()):
                return self._auth_obj

        obj = requests.get(self.__access_token_url,
                           verify=False,
                           params={
                               'grant_type': 'client_credentials',
                               'client_id': self._client_id,
                               'client_secret': self._client_secret},
                           timeout=(
                               self.__connect_timeout,
                               self.__socket_timeout)).json()

        obj['time'] = int(time.time())
        self._auth_obj = obj
        return obj

该函数首先计算当前的 _auth_obj 中的 access token是否过期,如果未过期直接返回;否则重新获取相关内容。

最后是调用接口的函数

def _request(self, url, data, headers=None):
        try:
            auth_obj = self._auth()
            params = self._get_params(auth_obj)
            headers = {'Content-Type': 'application/x-www-form-urlencoded'}
            response = requests.post(url=url,
                                     data=data,
                                     params=params,
                                     verify=False,
                                     headers=headers,
                                     timeout=(
                                         self.__connect_timeout,
                                         self.__socket_timeout,
                                     ))
            obj = json.loads(response.content.decode())

            if obj.get('error_code', '') == 110:
                auth_obj = self._auth(True)
                params = self._get_params(auth_obj)
                response = requests.post(url=url,
                                         data=data,
                                         params=params,
                                         verify=False,
                                         headers=headers,
                                         timeout=(
                                             self.__connect_timeout,
                                             self.__socket_timeout,
                                         ))
                obj = json.loads(response.content.decode())
        except (requests.exceptions.ReadTimeout, requests.exceptions.ConnectTimeout) as e:
            return {
                'error_code': 'SDK108',
                'error_msg': 'connection or read data timeout',
            }

        return obj

该函数的 url 参数表示接口的请求 url,data参数表示接口请求参数,然后是通过post进行调用。

图像处理接口类

代码如下:

from base import Base
from base import base64

client_id = 'your API Key'
client_secret = 'your Secret Key'


class ImageProcess(Base):
    __process_url = 'https://aip.baidubce.com/rest/2.0/image-process/v1/'
    __image_quality_enhance_url = __process_url + 'image_quality_enhance'
    __dehaze_url = __process_url + 'dehaze'
    __contrast_enhance_url = __process_url + 'contrast_enhance'
    __colorize_url = __process_url + 'colourize'
    __stretch_restore_url = __process_url + 'stretch_restore'
    __style_trans_url = __process_url + 'style_trans'

    def image_quality_enhance(self, image, options=None):
        options = options or {}
        data = {}
        data['image'] = base64.b64encode(image).decode()
        data.update(options)
        return self._request(self.__image_quality_enhance_url, data)

    def dehaze(self, image, options=None):
        options = options or {}
        data = {}
        data['image'] = base64.b64encode(image).decode()
        data.update(options)
        return self._request(self.__dehaze_url, data)

    def contrast_enhance(self, image, options=None):
        options = options or {}
        data = {}
        data['image'] = base64.b64encode(image).decode()
        data.update(options)
        return self._request(self.__contrast_enhance_url, data)

    def colorize(self, image, options=None):
        options = options or {}
        data = {}
        data['image'] = base64.b64encode(image).decode()
        data.update(options)
        return self._request(self.__colorize_url, data)

    def stretch_restore(self, image, options=None):
        options = options or {}
        data = {}
        data['image'] = base64.b64encode(image).decode()
        data.update(options)
        return self._request(self.__stretch_restore_url, data)

    def style_trans(self, image, options=None):
        options = options or {}
        data = {}
        data['image'] = base64.b64encode(image).decode()
        data.update(options)
        return self._request(self.__style_trans_url, data)

这6个函数接口的调用形式是一样的,image参数为输入的图像,options是除图像外的其他请求参数,其格式为 dict。如果不是为了区分也可以写成一个函数。

主程序代码

首先看一下界面展示图

从图中可以看到需要实现两个主要的功能,一个是图像上传,一个是图像检测。

图像上传的代码如下:

@app.route('/', methods=['POST', 'GET'])  # 添加路由
def upload():
    filename = '3efe5cc3e397933216ed48f99ad43e02.png'
    if request.method == 'POST':
        file = request.files['file']
        if not (file and allowed_file(file.filename)):
            return jsonify({"error": 1001, "msg": "请检查上传的图片类型,仅限于png、PNG、jpg、JPG、bmp"})
        base_path = os.path.dirname(__file__)  # 当前文件所在路径
        upload_path = os.path.join(base_path, 'static/images')
        if file and allowed_file(file.filename):
            filename = secure_filename(file.filename)
            file.save(os.path.join(upload_path, filename))
        return redirect(url_for('show_button', img_name=filename))
    return render_template('pages/upload_img.html', img_name=filename)

该函数通过Post 表单获取上传的图像的路径,然后将图像保存到本地,然后在 show_button.html 页面上进行展示。

图像检测的代码如下:

@app.route('/show_button/?', methods=['POST', 'GET'])  # 添加路由
def show_button(img_name):
    if request.method == 'POST':
        process_type = request.form.get('process_type')
        img = './static/images/' + img_name
        with open(img, 'rb') as f:
            img_data = f.read()
        if process_type == "1":
            res = ip_obj.image_quality_enhance(img_data)
            res_base64 = res['image']
        elif process_type == "2":
            res = ip_obj.dehaze(img_data)
            res_base64 = res['image']
        elif process_type == "3":
            res = ip_obj.contrast_enhance(img_data)
            res_base64 = res['image']
        elif process_type == "4":
            res = ip_obj.colorize(img_data)
            res_base64 = res['image']
        elif process_type == "5":
            res = ip_obj.stretch_restore(img_data)
            res_base64 = res['image']
        elif process_type == "6":
            audio_con = request.form.get('inlineRadioOptions')
            options = {'option': audio_con}
            res = ip_obj.style_trans(img_data, options)
            res_base64 = res['image']

        return render_template('pages/img_process.html',
                               img_name=img_name,
                               img_base64=res_base64,
                               pro_type=process_type,
                               op_dict=PROCESS_TYPE)

    return render_template('pages/show_button.html', img_name=img_name, op_dict=PROCESS_TYPE)

图像处理的类型是通过一个下拉菜单来进行选择的,因此通过Post表单获取所选择的处理类型,然后调用相应的接口函数对上传的图像进行处理,处理结果在 img_process.html页面进行展示。

结果展示

  1. 图像拉伸恢复
  2. 图像风格转换-卡通
  3. 图像风格转换-素描

体验感受
个人比较喜欢图像风格转换这类应用,但是目前只支持卡通和素描两种,文档中提到了之后会上线油画,比较期待。希望之后还能再推出两张图像的风格转换,将一张图像的风格用于另一张图像上。

收藏
点赞
0
个赞
TOP
切换版块