一个简易的五子棋bot

在两年前…我就想搞一个五子棋bot来愉快玩耍。然后在这么久之后,我不想写题然后颓废…我才搞出一个棋力一般的五子棋bot。(智障选手生无可恋)

当我第一次想写这个bot的时候,我就想肯定价值越大的位置越好嘛!于是直接判可以成5的位置,计算分数,然后贪最大的。当时我甚至没有怎么了解OI,所以根本不知道这个东西被称为估价函数。

结果呢?跑起来被我自己吊着打……于是当时我弃疗了,留下了一个vb写的对局棋盘框架。

究其原因,我的估价函数是非常不优秀……因为我直接取了做了估价函数;并且统计方式是,从一个点向上下左右扩展,这导致本来成不了5的点都会被统计进入答案。

一年后(去年)我重新看代码的时候发现了这个问题,于是我改变了估价方式,从统计一个点变成了对于所有成5的可能全部判一遍,并且给每种可能成5情况的所有点同时加上价值。这样做,虽然似乎不能准确地判断出死活(例如死4仍然有活4一半的价值),但是已经可以比较清晰的体现棋子位置的价值了。同时经过手动退火(……),找到了一些看起来比较优秀的估价值。跑起来,终于能下赢我自己啦!同时百度上面的智障五子棋小游戏也下平了。

然后拿去班上给同学玩,结果被吊着打。(……)

首先,这样做的话估价函数的强弱完全决定了这个bot棋力的强弱,但是事实上找到的估价函数并没有非常优秀。其次,我的统计方式本来就不太科学,例如上面的判死活等等。

然而最重要的是,贪心短视地令人嗤之以鼻。

Now Turn : #
+ 010203040506070809101112131415
01                              
02                              
03                              
04                              
05                     O        
06               #   #          
07     O ? O   O O #   #        
08       ? # O # # # O          
09         O # # O O            
10       # O O   #              
11                              
12                              
13                              
14                              
15

考虑如上的局面。此时如果是纯贪心,那么通常来说都会选择(7, 5),因为按照估计函数的计算(7, 5)更加优秀;然而事实上,如果不走(8, 5),那么白棋走(8, 5)就已经是杀棋了。

于是那时的我想起,有往后算XX步的说法,恍然大悟:哦原来可以使用搜索。然后我写了一个搜三步的程序……不过因为我当时对局面的判断取得是总和而不是最大值,效果非常差。

然后今年我又想颓废,于是我就重新看了看自己的程序,发现简直是个逗逼……

重新写了一个基于以下策略的五子棋bot:
1. Min-Max搜索,加上Alpha-Beta剪枝;
2. 为方便Alpha-Beta剪枝,每步对决策按估价函数排序,选择尽量大的先搜索;
3. 估价函数应设计为一定要能够体现出输赢,并且尽量反映局势。注意到五子棋的输赢和全局关系不大,只要一个地方非常优就可以赢了,所以这里我采用的是max(我)-max(敌)作为反应局势。

看起来非常靠谱……事实上,我调调写写搞了很久,结果效果和贪心差不多。

直到一天半后,我才发现:

哇,我botzone上面的黑白读反了。

改了以后瞬间在botzone上登顶了。(UPD: 已经被打爆啦 QAQ)

(2017.06.12 17:05)

然而实际上这个bot棋力是不够的……甚至我随便下了一个手机五子棋都下不过;一些同学也可以轻易手爆这个bot。

假定估价函数足够准确,那么采用Min-Max搜索出的结果应该是非常好的。
然而现实总是没有那么令人满意……总是。很难找到一种准确的评价局面的方式,原因和贪心的错误类似。这里如果没法准确评价局面,那么Min-Max搜出来的最好结果也不是最好的。

更正确的姿势是用蒙特卡洛树搜索,即基于一定程度随机化的搜索。这种搜索通常来说都有一个特性,以随机化为基础,并且只看游戏终局带来的影响。即,不到游戏结束,不停止往更深的地方搜索。这样虽然是随机化的,但是效果会好很多。Hob让我写uct……然而我觉得我不能继续颓废在bot上面了,于是就暂时先写到这里了。

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <queue>
#include <ctime>
#include <cstdlib>
#include <iomanip>
using namespace std;

inline int min(int a, int b) {return a < b ? a : b;}
inline int max(int a, int b) {return a > b ? a : b;}
inline double min(double a, double b) {return a < b ? a : b;}
inline double max(double a, double b) {return a > b ? a : b;}

#define Lim 15
const double eps = 1e-4;

struct point {
    int x,y;
    double v;
    point () {}
    point (int a, int b, double c) {x = a, y = b, v = c;}
    bool operator < (point a) const {
        return v>a.v;
    }
} tmp[23][233];

double Src[Lim + 1];

int G[Lim + 1][Lim + 1],cnt[3];
double t1[Lim + 1][Lim + 1], t2[Lim + 1][Lim + 1], (*att)[Lim + 1] = t1, (*dfn)[Lim + 1] = t2;

string myData;

char tt[Lim * Lim];

void printBoard() {
    myData += "+ 010203040506070809101112131415\n";
    for (int i = 1; i <= Lim; i++) {
        sprintf(tt, "%02d",i); myData += tt;
        for (int j = 1;j <= Lim; j++) {
            myData += ' ';
            switch (G[i][j]) {
                case 0: myData += ' '; break;
                case 1: myData += '#'; break;
                case 2: myData += 'O'; break;
            }
        }
        myData += '\n';
    }
}

inline void check(int a,int b) {
    memset(att,0,sizeof(double) * (Lim + 1) * (Lim + 1));
    memset(dfn,0,sizeof(double) * (Lim + 1) * (Lim + 1));
    for (int i=1;i<=Lim;i++) {
        cnt[1] = cnt[2] = 0;
        for (int k = 0; k <= 3; k++) cnt[G[i][1 + k]]++;
        for (int j=1;j<=Lim-4;j++) {
            cnt[G[i][j+4]]++;
            if (cnt[b]==0) for (int k=0;k<=4;k++) att[i][j+k] += Src[cnt[a]];
            if (cnt[a]==0) for (int k=0;k<=4;k++) dfn[i][j+k] += Src[cnt[b]];
            cnt[G[i][j]]--;
        }
    }
    for (int j=1;j<=Lim;j++) {
        cnt[1] = cnt[2] = 0;
        for (int k = 0; k <= 3; k++) cnt[G[1 + k][j]]++;
        for (int i=1;i<=Lim-4;i++) {
            cnt[G[i+4][j]]++;
            if (cnt[b]==0) for (int k=0;k<=4;k++) att[i+k][j] += Src[cnt[a]];
            if (cnt[a]==0) for (int k=0;k<=4;k++) dfn[i+k][j] += Src[cnt[b]];
            cnt[G[i][j]]--;
        }
    }
    for (int i=1;i<=Lim-4;i++) {
        for (int j=1;j<=Lim-4;j++) {
            cnt[1]=cnt[2]=0;
            for (int k=0;k<=4;k++) cnt[G[i+k][j+k]]++;
            if (cnt[b]==0) for (int k=0;k<=4;k++) att[i+k][j+k] += Src[cnt[a]];
            if (cnt[a]==0) for (int k=0;k<=4;k++) dfn[i+k][j+k] += Src[cnt[b]];
        }
    }
    for (int i=5;i<=Lim;i++) {
        for (int j=1;j<=Lim-4;j++) {
            cnt[1]=cnt[2]=0;
            for (int k=0;k<=4;k++) {
                cnt[G[i-k][j+k]]++;
            }
            if (cnt[b]==0) for (int k=0;k<=4;k++) att[i-k][j+k] += Src[cnt[a]];
            if (cnt[a]==0) for (int k=0;k<=4;k++) dfn[i-k][j+k] += Src[cnt[b]];
        }
    }
}

inline void upd(int x, int y, int a, int b, int f) {
    for (int i = max(1, x - 4); i <= x && i + 4 <= Lim; i++) {
        cnt[1] = cnt[2]=0;
        for (int k=0;k<=4;k++) cnt[G[i+k][y]]++;
        if (cnt[b]==0) for (int k=0;k<=4;k++) att[i+k][y] += f * Src[cnt[a]];
        if (cnt[a]==0) for (int k=0;k<=4;k++) dfn[i+k][y] += f * Src[cnt[b]];
    }
    for (int j = max(1, y - 4); j <= y && j + 4 <= Lim; j++) {
        cnt[1] = cnt[2]=0;
        for (int k=0;k<=4;k++) cnt[G[x][j + k]]++;
        if (cnt[b]==0) for (int k=0;k<=4;k++) att[x][j + k] += f * Src[cnt[a]];
        if (cnt[a]==0) for (int k=0;k<=4;k++) dfn[x][j + k] += f * Src[cnt[b]];
    }   
    for (int d = min(4, min(x - 1, y - 1)), i = x - d, j = y - d;i <= Lim-4 && j <= Lim-4 && i <= x;i++, j++) {
        cnt[1]=cnt[2]=0;
        for (int k=0;k<=4;k++) cnt[G[i+k][j+k]]++;
        if (cnt[b]==0) for (int k=0;k<=4;k++) att[i+k][j+k] += f * Src[cnt[a]];
        if (cnt[a]==0) for (int k=0;k<=4;k++) dfn[i+k][j+k] += f * Src[cnt[b]];
    }
    for (int d = min(4, min(Lim - x, y - 1)), i = x + d, j = y - d;i >= 4 && j <= Lim-4 && j <= y;i--, j++) {
        cnt[1]=cnt[2]=0;
        for (int k=0;k<=4;k++) cnt[G[i-k][j+k]]++;
        if (cnt[b]==0) for (int k=0;k<=4;k++) att[i-k][j+k] += f * Src[cnt[a]];
        if (cnt[a]==0) for (int k=0;k<=4;k++) dfn[i-k][j+k] += f * Src[cnt[b]];
    }
}

inline void update(int x, int y, int a, int b, int c) {
    upd(x, y, a, b, -1);
    G[x][y] = c;
    upd(x, y, a, b, 1);
    swap(att, dfn);
}

const int X = 8;

double dfs(int a, int b, int dep, double alpha, double beta) {
//  printBoard();
    if (!dep) {
        double ma = 0, mb = 0;
        for (int i = 1; i <= 15; i++)
            for (int j = 1; j <= 15; j++) {
                ma = max(ma, att[i][j]);
                mb = max(mb, dfn[i][j]);
            }
        return ma - mb;
    }
    int tot = 0;
    for (int i = 1; i <= 15; i++)
        for (int j = 1; j <= 15; j++) {
            if (att[i][j] > 1e8) return 1e20;
            else if (dfn[i][j] > 1e8) return -1e20;
            if (!G[i][j]) {
                tmp[dep][++tot].x = i;
                tmp[dep][tot].y = j;
                tmp[dep][tot].v = att[i][j] + dfn[i][j];
            }
        }
    nth_element(tmp[dep] + 1, tmp[dep] + min(tot, X), tmp[dep] + tot + 1);
    sort(tmp[dep] + 1, tmp[dep] + min(tot, X) + 1);
    for (int i = 1; i <= min(tot, X); i++) {
        nth_element(tmp[dep] + 1, tmp[dep] + i, tmp[dep] + tot + 1);
        update(tmp[dep][i].x, tmp[dep][i].y, a, b, a);
        double val = -dfs(b, a, dep - 1, -beta - eps, -alpha - eps);
        update(tmp[dep][i].x, tmp[dep][i].y, b, a, 0);
        if (val >= beta) return val;
        alpha = max(alpha, val);
    }
    return alpha;
}

int now, ene, First;

point work() {
    int stp = min(X, cnt[1] * 3);
    check(now, ene);
    int tot = 0;
    for (int i = 1; i <= 15; i++)
        for (int j = 1; j <= 15; j++) {
            if (!G[i][j]) tmp[Lim][++tot] = point(i, j, att[i][j] + dfn[i][j]);
        }
    //printBoard();

    sort(tmp[Lim] + 1, tmp[Lim] + tot + 1);
    double ma = -1e233; int x = tmp[Lim][1].x, y = tmp[Lim][1].y;
    time_t st = clock();
    for (int i = 1; i <= min(11, tot) && (clock() - st) * 1. / CLOCKS_PER_SEC  < 800; i++) {
        if (tmp[Lim][i].v <= tmp[Lim][1].v / 10) break;
        update(tmp[Lim][i].x, tmp[Lim][i].y, now, ene, now);
        double pos = -dfs(ene, now, stp, -1e233, -ma - eps);
    //  myData += to_string(pos) + " -> " + to_string(tmp[Lim][i].x) +  " " + to_string(tmp[Lim][i].y) + "\n"; 
        update(tmp[Lim][i].x, tmp[Lim][i].y, ene, now, 0);
        if (pos - eps > ma) {
            ma = pos;
            x = tmp[Lim][i].x, y = tmp[Lim][i].y;
        }
    }
    return point(x, y, 0);
}
#ifdef _BOTZONE_ONLINE
#include "jsoncpp/json.h"

int tot = 0;

void placeAt(int x, int y, int c) {
    if (x < 0) return;
    G[x + 1][y + 1] = c; cnt[c]++;
}

Json::Reader reader;
Json::Value input;

int A, B;

void read() {
    string str;
    getline(cin, str);
    reader.parse(str, input); 
    int turnID = input["responses"].size();
    for (int i = 0; i < turnID; i++) {
        placeAt(input["requests"][i]["x"].asInt(), input["requests"][i]["y"].asInt(),2);
        placeAt(input["responses"][i]["x"].asInt(), input["responses"][i]["y"].asInt(),1);
    }
    placeAt(input["requests"][turnID]["x"].asInt(), input["requests"][turnID]["y"].asInt(),2);
    if (tot&1) {
        First=0;now=2;ene=1;
    }
    else {
        First=1;now=1;ene=2;
    }
    A = cnt[1], B = cnt[2];
    cnt[0] = Lim * Lim - cnt[1] - cnt[2];
}

Json::Value Position(int x,int y) {
    Json::Value action;
    action["x"] = x-1;
    action["y"] = y-1;
    return action;
}

void print(int x,int y) {
    Json::Value ret;
//  myData = "---- " + to_string(A + B) + " ----\n" + myData + "\n";
    ret["response"] = Position(x,y);
//  if (A + B < 3) ret["globaldata"] = myData;
//  else ret["globaldata"] = input["globaldata"].asString() + myData;
    Json::FastWriter writer;
    cout << writer.write(ret) << endl;
}

#else 
void read() {
    freopen("board.in","r",stdin);
    scanf("%d",&now);
    ene=3-now;
    //printf("[%d %d]\n", now, ene);
    for (int i=1;i<=15;i++) {
        for (int j=1;j<=15;j++) {
            int ch=getchar();
            while (ch<'0') ch=getchar();
            G[i][j]=ch-'0';
            cnt[G[i][j]]++;
        }
    }
    freopen("/dev/tty", "r", stdin);
}

void print(int x, int y) {
//  cout << myData << endl;
    freopen("decision.out","w",stdout);
    printf("%d %d\n", x, y);  
    fclose(stdin);
    fclose(stdout);
}
#endif

int main() {
    for (int i = 0; i < 5; i++) Src[i] = pow(10, i);
    Src[2] *= 2.501;
    Src[5] = 1e10;
    read();
    if (cnt[0] == Lim * Lim) {
        print((Lim + 1) / 2, (Lim + 1) / 2);
        return 0;
    }
    if (cnt[1]==cnt[2]) First=1;
    else First=0;
    point p = work();
    print(p.x, p.y);
    return 0;
}

发表评论

电子邮件地址不会被公开。 必填项已用*标注