开放能力
开发平台
行业应用
生态合作
开发与教学
资讯 社区 控制台
技术能力
语音技术
图像技术
文字识别
人脸与人体识别
视频技术
AR与VR
自然语言处理
知识图谱
数据智能
场景方案
部署方案
行业应用
智能教育
智能医疗
智能零售
智能工业
企业服务
智能政务
智能农业
信息服务
智能园区
智能硬件
BML 全功能AI开发平台

    Tensorflow2.3.0代码规范

    Tensorflow 2.3.0代码规范

    基于Tensorflow2.3.0框架的MNIST图像分类,训练数据集tf_train_data2.zip点击这里下载。
    如下所示是其超参搜索任务中一个超参数组合的训练代码,代码会通过argparse模块接受在平台中填写的信息,请保持一致。

    tensorflow2.3_autosearch.py示例代码

    # -*- coding:utf-8 -*-
    """ tensorflow2 train demo """
    import tensorflow as tf
    import os
    import numpy as np
    import time
    import argparse
    from rudder_autosearch.sdk.amaas_tools import AMaasTools
    
    def parse_arg():
        """parse arguments"""
        parser = argparse.ArgumentParser(description='tensorflow2.3 mnist Example')
        parser.add_argument('--train_dir', type=str, default='./train_data',
                            help='input data dir for training (default: ./train_data)')
        parser.add_argument('--test_dir', type=str, default='./test_data',
                            help='input data dir for test (default: ./test_data)')
        parser.add_argument('--output_dir', type=str, default='./output',
                            help='output dir for auto_search job (default: ./output)')
        parser.add_argument('--job_id', type=str, default="job-1234",
                            help='auto_search job id (default: "job-1234")')
        parser.add_argument('--trial_id', type=str, default="0-0",
                            help='auto_search id of a single trial (default: "0-0")')
        parser.add_argument('--metric', type=str, default="acc",
                            help='evaluation metric of the model')
        parser.add_argument('--data_sampling_scale', type=float, default=1.0,
                            help='sampling ratio of the data (default: 1.0)')
        parser.add_argument('--batch_size', type=int, default=100,
                            help='number of images input in an iteration (default: 100)')
        parser.add_argument('--lr', type=float, default=0.001,
                            help='learning rate of the training (default: 0.001)')
        parser.add_argument('--epoch', type=int, default=5,
                            help='number of epochs to train (default: 5)')
        args = parser.parse_args()
        args.output_dir = os.path.join(args.output_dir, args.job_id, args.trial_id)
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)
        print("job_id: {}, trial_id: {}".format(args.job_id, args.trial_id))
        return args
    
    def load_data(data_sampling_scale):
        """ load data """
        mnist = tf.keras.datasets.mnist
        work_path = os.getcwd()
        (x_train, y_train), (x_test, y_test) = mnist.load_data('%s/train_data/mnist.npz' % work_path)
        # sample training data
        np.random.seed(0)
        sample_data_num = int(data_sampling_scale * len(x_train))
        idx = np.arange(len(x_train))
        np.random.shuffle(idx)
        x_train, y_train = x_train[0:sample_data_num], y_train[0:sample_data_num]
        x_train, x_test = x_train / 255.0, x_test / 255.0
        return (x_train, x_test), (y_train, y_test)
    
    def Model(learning_rate):
        """Model"""
        model = tf.keras.models.Sequential([
            tf.keras.layers.Flatten(input_shape=(28, 28)),
            tf.keras.layers.Dense(128, activation='relu'),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(10, activation='softmax')
        ])
        model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate),
                      loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'])
        return model
    
    def evaluate(model, x_test, y_test):
        """evaluate"""
        loss, acc = model.evaluate(x_test, y_test, verbose=2)
        print("accuracy: %f" % acc)
        return acc
    
    def report_final(args, metric):
        """report_final_result"""
        # 结果上报sdk
        amaas_tools = AMaasTools(args.job_id, args.trial_id)
        metric_dict = {args.metric: metric}
        for i in range(3):
            flag, ret_msg = amaas_tools.report_final_result(metric=metric_dict,
                                                                 export_model_path=args.output_dir,
                                                                 checkpoint_path="")
            print("End Report, metric:{}, ret_msg:{}".format(metric, ret_msg))
            if flag:
                break
            time.sleep(1)
        assert flag, "Report final result to manager failed! Please check whether manager'address or manager'status " \
                     "is ok! "
    
    def main():
        """main"""
        # 获取参数
        args = parse_arg()
        # 加载数据集
        (x_train, x_test), (y_train, y_test) = load_data(args.data_sampling_scale)
        # 模型定义
        model = Model(args.lr)
        # 模型训练
        model.fit(x_train, y_train, epochs=args.epoch, batch_size=args.batch_size)
        # 模型保存
        model.save(args.output_dir)
        # 模型评估
        acc = evaluate(model, x_test, y_test)
        # 上报结果
        report_final(args, metric=acc)
    
    if __name__ == '__main__':
        main()

    示例代码对应的yaml配置如下,请保持格式一致

    pwo_search_demo.yml示例内容

    #搜索算法参数
    search_strategy:
      algo: PARTICLE_SEARCH #搜索策略:粒子群算法
      params:
        population_num: 8 #种群个体数量 | [1,10] int类型
        round: 10 #迭代轮数  |[5,50] int类型
        inertia_weight: 0.5 # 惯性权重  |(0,1] float类型
        global_acceleration: 1.5 #全局加速度 |(0,4] float类型
        local_acceleration: 1.5 #个体加速度  |(0,4] float类型
    
    #单次训练时数据的采样比例,单位%
    data_sampling_scale: 100 #|(0,100] int类型
    
    #评价指标参数
    metrics:
      name: acc #评价指标  | 任意字符串 str类型
      goal: MAXIMIZE #最大值/最小值 | str类型   MAXIMIZE or MINIMIZE   必须为这两个之一(也即支持大写)
      expected_value: 100 #早停标准值,评价指标超过该值则结束整个超参搜索,单位%  |无限制 int类型
    
    #搜索参数空间
    search_space:
      batch_size:
        htype: choice
        value: [100, 200, 300, 400, 500, 600]
      lr:
        htype: loguniform
        value: [0.0001, 0.9]
      epoch:
        htype: choice
        value: [5, 10, 12]
    上一篇
    yaml文件编写规范
    下一篇
    TensorFlow 1.13.2代码规范