让AI帮你玩游戏(二)训练模型并获取结果_用ai训练玩游戏-程序员宅基地

技术标签: python  机器学习  深度学习  Ai盘游戏  人工智能  游戏  

让AI帮你玩游戏


书接上文
我们在上一期介绍了我们整体的思路以及创建训练数据,搭建ssd-resnet50模型,现在我们开始训练我们的模型,并用我们训练好的模型获得浮标所在的坐标。

 

让AI帮你玩游戏(二)训练模型并获取结果

训练模型

没啥好说的直接上代码,这里的代码我进行了注释,大概解释了每段代码都在干什么。

class_id = 1
class_name = 'drift'
num_classes = 1
num_boxes = 1
batch_size = 4
learning_rate = 0.01
num_batches = 100  # 这里是训练步数,数量太大会过拟合,效果反而不好,针对我们样本数这里设为100足够
# 我们只选择模型的top layers变量进行训练,而不是整个模型,我们用少量样本训练模型会有过拟合的现象,不过我们也不是造原子弹,所以无所谓拉
trainable_variables = detection_model.trainable_variables
to_fine_tune = []
prefixes_to_train = [
    'WeightSharedConvolutionalBoxPredictor/WeightSharedConvolutionalBoxHead',
    'WeightSharedConvolutionalBoxPredictor/WeightSharedConvolutionalClassHead']
for var in trainable_variables:
    if any([var.name.startswith(prefix) for prefix in prefixes_to_train]):
        to_fine_tune.append(var)


# 为单个训练步骤设置 forward + backward 

def get_model_train_step_function(model, optimizer, vars_to_fine_tune):
    """这个函数是用来获取训练step的."""
   @tf.function
    def train_step_fn(image_tensors,
                      groundtruth_boxes_list,
                      groundtruth_classes_list):
        shapes = tf.constant(batch_size * [[im_width, im_height, 3]], dtype=tf.int32)
        model.provide_groundtruth(
            groundtruth_boxes_list=groundtruth_boxes_list,
            groundtruth_classes_list=groundtruth_classes_list)
        with tf.GradientTape() as tape:
            preprocessed_images = tf.concat(
                [detection_model.preprocess(image_tensor)[0]
                 for image_tensor in image_tensors], axis=0)
            prediction_dict = model.predict(preprocessed_images, shapes)
            losses_dict = model.loss(prediction_dict, shapes)
            total_loss = losses_dict['Loss/localization_loss'] + losses_dict['Loss/classification_loss']
            gradients = tape.gradient(total_loss, vars_to_fine_tune)
            optimizer.apply_gradients(zip(gradients, vars_to_fine_tune))
        return total_loss
    return train_step_fn

# 采用SGD方法进行优化
optimizer_ = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9)
train_step_fn_ = get_model_train_step_function(
                                            detection_model,
                                            optimizer_,
                                            to_fine_tune)

# 从这里开始训练模型
for idx in range(num_batches):
    all_keys = list(range(len(train_images_np)))
    random.shuffle(all_keys)
    example_keys = all_keys[:batch_size]
   gt_boxes_list = [gt_box_tensors[key] for key in example_keys]
    gt_classes_list = [gt_classes_one_hot_tensors[key] for key in example_keys]
    im_tensors = [train_image_tensors[key] for key in example_keys]
    total_loss_ = train_step_fn_(im_tensors, gt_boxes_list, gt_classes_list)
    if idx % 10 == 0:
        print('batch ' + str(idx) + ' of ' + str(num_batches) + ', loss=' + str(total_loss_.numpy()), flush=True)
# 训练结束!

在这里插入图片描述
训练结果如图,可见我们的loss在训练100次就达到了0.01,已经很低了,可以认为收敛了!

用训练好的Model获取坐标

首先将我们定义一个函数将图片转换为array

from six import BytesIO
from PIL import Image
from object_detection.utils import visualization_utils as viz_utils
def load_image_into_numpy_array(path):
    img_data = tf.io.gfile.GFile(path, 'rb').read()
    image_ = Image.open(BytesIO(img_data))
    (width, height) = image_.size
    return np.array(image_.getdata()).reshape(
        (height, width, 3)).astype(np.uint8)

path就是图片的路径,没啥好说的。
下面我们来定义一个将我们获取到的坐标画到图片上的函数

def plot_detections(image_np,
                    boxes,
                    classes,
                    scores,
                    category_index,
                    figsize=(12, 16),
                    image_name=None):
    image_np_with_annotations = image_np.copy()
    viz_utils.visualize_boxes_and_labels_on_image_array(
        image_np_with_annotations,
        boxes,
        classes,
        scores,
        category_index,
        use_normalized_coordinates=True,
        min_score_thresh=0.6)
    if image_name:
        plt.imsave(image_name, image_np_with_annotations)
    else:
        plt.imshow(image_np_with_annotations)

定义预测函数,用来获取我们模型的检测结果


@tf.function
def detect(tensor_):
    preprocessed_image, shapes_ = detection_model.preprocess(tensor_)
    prediction_dict_ = detection_model.predict(preprocessed_image, shapes_)
    return detection_model.postprocess(prediction_dict_, shapes)

现在我们来获取我们想要的结果!


label_id_offset = 1
for i in range(len(test_images_np)):
    print(i)
    input_tensor = tf.convert_to_tensor(test_images_np[i],
                                        dtype=tf.float32)
    detections = detect(input_tensor)
    plot_detections(
        test_images_np[i][0],
        detections['detection_boxes'][0].numpy(),
        detections['detection_classes'][0].numpy().astype(np.uint32) + label_id_offset,
        detections['detection_scores'][0].numpy(),
        category_index,
        figsize=(15, 20),
        image_name="results/gif_frame_" + ('%02d' % i) + ".jpg")

我们来看一下效果:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
结果显而易见,在这么复杂的环境下识别率很高,其中有个跟花重叠的都可以识别,其实超过60%就可以认为成功识别,可见我们的模型准确性还是可以的!

实现我们想要的功能

整个工程中最难的部分已经解决,现在我们可以用得到的坐标,计算出浮标的中心点位置,然后就可以进行各种骚操作了,需要提醒的是,当第一次运行的时候需要把从图像获取到的坐标跟游戏里鼠标的位置进行校正,只要你游戏里视角发生变化都需要进行校正!切记!

关于声音的判断,有无数种方法,你可以计算获取到声音的特征值,做比较,相似就收杆,也可以判断声卡发出声音的响度来判断(我比较懒,用的后者,嘿嘿),你也可以搭建一个神经网络模型来训练,让它可以听出来这个声音(声音检测)等等。。。

我们有坐标了,有收杆时机的判断了,就可以实现主逻辑了,至于主逻辑嘛我就不多做介绍了,既然前边难度那么大的都可以轻松解决,这点事情肯定不会难到您的。至于功能的实现可以用windows API啊,HOOK啊,驱动啊等等想用什么用什么,群魔乱舞。。。噗,不对不对,是八仙过海各显神通。。。好了本教程到此结束!

最后再次声明:本文所涉及的内容仅用作学习研究,严禁用于非法用途,游戏中违反相关协议可能会使你失去你的游戏账号!

教程完结,祝大家生活愉快!(所有代码已上传QQ群文件,进群请加VX:JTSMJJ)

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/pp2351/article/details/109302828

智能推荐

Java+springboot+MYSQL牙科诊所预约系统75174-计算机毕业设计项目选题推荐(赠源码)-程序员宅基地

文章浏览阅读485次,点赞13次,收藏9次。本智慧综合管理是针对目前牙科诊所预约系统的实际需求,从实际工作出发,对过去的牙科诊所预约系统存在的问题进行分析,结合计算机系统的结构、概念、模型、原理、方法,在计算机各种优势的情况下,采用目前最流行的B/S结构、java技术MySQL数据库设计并实现的。本牙科诊所预约系统主要包括登录模块的实现、系统模块、管理员模块、医生模块、用户模块等多个模块。它帮助牙科诊所预约系统实现了信息化、网络化,通过测试,实现了系统设计目标,相比传统的管理模式,本系统合理的利用了网络数据资源,有效的减少了牙科诊所预约系统的经济投入

【pandas】踩了merge操作的一个坑_file "pandas/_libs/join.pyx", line 104, in pandas.-程序员宅基地

文章浏览阅读5.2k次。最近一个上线半年的爬虫挂了,错误信息如题,挂在了一个dataframe的merge操作上。仔细看了看源代码,这个merge操作非常简单,目的只是想看两个数据集中date(%Y-%m-%d %H-%M-%S)交集,然后再决定下一步的操作。首先怀疑数据量随着时间推移变得很大,得分块操作了。但是查了一下当下的数据量,两个dataframe都只有几十万行,数据量并不是很大,反而Jenkins serve..._file "pandas/_libs/join.pyx", line 104, in pandas._libs.join.left_outer_join memoryerror

java代码在图片上画框_java ffmpage图片 画框-程序员宅基地

文章浏览阅读1w次,点赞2次,收藏8次。有时候对于一些截图需要通过画框重点显示,用java代码在图片上画框的代码如下:package imagetest;import java.awt.Color;import java.awt.Graphics;import java.awt.image.BufferedImage;import java.io.File;import java.io.FileInputSt_java ffmpage图片 画框

android建ftp服务器,Android 快速搭建FTP服务器的方法-程序员宅基地

文章浏览阅读2.2k次。一、概述打开你的手机,找到文件管理->分类->远程管理,点击启动服务,这样大家可以在局域网内使用电脑访问你手机上的文件了,当然你也可以设置账号和密码,防止“小人”共享你手机上的资源-.-,那如果自己动手,该如何实现这个小功能呢?二、实现1、导入相关的jar包,并在build.gradle添加相应的依赖,如图。2.创建服务配置文件在values文件夹下新建的xml文件,方便在代码中的引用..._android ftp 服务器

探索Awesome ML Demos with iOS:让机器学习触手可及-程序员宅基地

文章浏览阅读389次,点赞3次,收藏6次。探索Awesome ML Demos with iOS:让机器学习触手可及项目地址:https://gitcode.com/tucan9389/awesome-ml-demos-with-ios在这个数字化的时代,机器学习(ML)已经成为了软件开发的重要组成部分,尤其在移动应用中。Awesome ML Demos with iOS 是一个精心策划的GitHub项目,它汇集了众多以iOS平台为基...

biubiubiu坐地铁 期望dp_n个座位地铁坐下人数的期望值-程序员宅基地

文章浏览阅读1.2k次。链接:https://ac.nowcoder.com/acm/contest/642/M来源:牛客网题目描述BiuBiuBiu 每次出去玩都要去坐地铁,BiuBiuBiu 观察到,当地铁上人比较少的时候,大家都会选择那些与其他人不相邻的座位,现在地铁上有 n 个座位排成一排,1 号座位与 2 号相邻,n 号座位与 n-1 号相邻,除了 1 号与 n 号座位,任意 i 号座位都与 i-..._n个座位地铁坐下人数的期望值

随便推点

服务器响应为 5.7.1,执行发送邮件Send方法时,报错:邮箱不可用。 服务器响应为: 5.7.1 Unable to relay for [email protected]程序员宅基地

文章浏览阅读612次。php常用方法总结/** * created by Tina * time 2015-1-6 10:31 * textarea中传入字符串的处理,返回数组,传入的字符串以换行分割; * 拆分,压缩空格,去除空值,去重复 ...bzoj 1097 [POI2007]旅游景点atr(最短路,状压DP)[题意] 给定一个n点m边的无向图,要求1开始n结束而且顺序经..._事务失败 服务器响应为5.7.1

Android Studio快捷键以及导入Eclipse项目_寻求升级帮助,emu i com斜杠emot,ion do w-程序员宅基地

文章浏览阅读508次。Android Studio常用快捷键1. Ctrl+D: 集合了复制和粘贴两个操作,如果有选中的部分就复制选中的部分,并在选中部分的后面粘贴出来,如果没有选中的部分,就复制光标所在的行,并在此行的下面粘贴出来。2. Ctrl+空格: 输入代码时按此组合键会列出与之相匹配的类、方法名、成员变量等,起智能提示的作用。在编辑XML文件一样有用。3. Ctrl+向下箭头 或Ctr_寻求升级帮助,emu i com斜杠emot,ion do w

区块链共识算法综述论文阅读笔记:A Review on Consensus Algorithm of Blockchain_区块链共识算法论文-程序员宅基地

文章浏览阅读397次。本文是区块链共识算法的综述论文“A Review on Consensus Algorithm of Blockchain”的阅读笔记,论文对区块链的共识算法进行了全面的描述,但是受限于时代因素,里面的一些内容有一定错误或过时,例如PoS+PoW被分类为PoS、认为PoW具有无限的可扩展性等。_区块链共识算法论文

C#中[WebMethod]的用法,aspx、ashx、asmx-程序员宅基地

文章浏览阅读361次。在.net 3.5的情况下前台JQuery做Ajax的时候,服务器端(1)可以调用aspx.cs 中声明带有[WebMehtod]的public static 的方法(不需要自己手动添加web.config的配置)(2)可以调用 *.asmx (web服务) 里面加了[webmethod]的方法(不能写静态,写静态就调用不到了)需要在asmx里面 去掉 [System.Web.Scri..._asmx webmethod语法

Unity编辑器扩展: GUILayout、EditorGUILayout 控件整理_editorguilayout.popup-程序员宅基地

文章浏览阅读1.4w次,点赞15次,收藏76次。GUILayoutGUILayoutOption基本每个控件方法都有一个可选参数是GUILayoutOption[] Options 这是一个可以控制组件大小之类的选项,在GUILayout类中共有8个。GUILayout.Height()GUILayout.Width()GUILayout.MaxHeight()GUILayout..._editorguilayout.popup

Android学习--Fragment-程序员宅基地

文章浏览阅读874次。学习目标:提示:这里可以添加学习目标例如: 一周掌握 Java 入门知识学习内容:提示:这里可以添加要学的内容例如:搭建 Java 开发环境掌握 Java 基本语法掌握条件语句掌握循环语句学习时间:提示:这里可以添加计划学习的时间例如:周一至周五晚上 7 点—晚上9点周六上午 9 点-上午 11 点周日下午 3 点-下午 6 点学习产出:提示:这里统计学习计划的总量例如: 技术笔记 2 遍 CSDN 技术博客 3 篇 习的 vlog 视频 1

推荐文章

热门文章

相关标签