single-threaded monte carlo + some fixes

This commit is contained in:
BetaSteward 2012-01-16 18:55:35 -05:00
parent 2e21b7197b
commit 377dd54fca
4 changed files with 106 additions and 51 deletions

View file

@ -52,7 +52,8 @@ import org.apache.log4j.Logger;
*/
public class ComputerPlayerMCTS extends ComputerPlayer<ComputerPlayerMCTS> implements Player {
private static final int thinkTimeRatioThreshold = 20;
private static final int THINK_MIN_RATIO = 20;
private static final int THINK_MAX_RATIO = 100;
protected transient MCTSNode root;
protected int maxThinkTime;
@ -112,8 +113,10 @@ public class ComputerPlayerMCTS extends ComputerPlayer<ComputerPlayerMCTS> imple
if (root != null) {
MCTSNode newRoot = null;
newRoot = root.getMatchingState(game.getState().getValue(false, game));
if (newRoot != null)
if (newRoot != null) {
newRoot.emancipate();
logger.info("choose action:" + newRoot.getAction() + " success ratio: " + newRoot.getWinRatio());
}
else
logger.info("unable to find matching state");
root = newRoot;
@ -258,28 +261,58 @@ public class ComputerPlayerMCTS extends ComputerPlayer<ComputerPlayerMCTS> imple
logger.info("applyMCTS - Thinking for " + (endTime - startTime)/1000000000.0 + "s");
if (thinkTime > 0) {
List<MCTSExecutor> tasks = new ArrayList<MCTSExecutor>();
for (int i = 0; i < cores; i++) {
Game sim = createMCTSGame(game);
MCTSPlayer player = (MCTSPlayer) sim.getPlayer(playerId);
player.setNextAction(action);
MCTSExecutor exec = new MCTSExecutor(sim, playerId, thinkTime);
tasks.add(exec);
// List<MCTSExecutor> tasks = new ArrayList<MCTSExecutor>();
// for (int i = 0; i < cores; i++) {
// Game sim = createMCTSGame(game);
// MCTSPlayer player = (MCTSPlayer) sim.getPlayer(playerId);
// player.setNextAction(action);
// MCTSExecutor exec = new MCTSExecutor(sim, playerId, thinkTime);
// tasks.add(exec);
// }
//
// try {
// pool.invokeAll(tasks);
// } catch (InterruptedException ex) {
// logger.warn("applyMCTS interrupted");
// }
//
// for (MCTSExecutor task: tasks) {
// root.merge(task.getRoot());
// task.clear();
// }
// tasks.clear();
MCTSNode current;
int simCount = 0;
while (true) {
long currentTime = System.nanoTime();
if (currentTime > endTime)
break;
current = root;
// Selection
while (!current.isLeaf()) {
current = current.select(this.playerId);
}
int result;
if (!current.isTerminal()) {
// Expansion
current.expand();
// Simulation
current = current.select(this.playerId);
result = current.simulate(this.playerId);
simCount++;
}
else {
result = current.isWinner(this.playerId)?1:-1;
}
// Backpropagation
current.backpropagate(result);
}
try {
pool.invokeAll(tasks);
} catch (InterruptedException ex) {
logger.warn("applyMCTS interrupted");
}
for (MCTSExecutor task: tasks) {
root.merge(task.getRoot());
task.clear();
}
tasks.clear();
logger.info("Created " + root.getNodeCount() + " nodes - size: " + root.size());
logger.info("Simulated " + simCount + " games - nodes in tree: " + root.size());
displayMemory();
}
@ -287,32 +320,38 @@ public class ComputerPlayerMCTS extends ComputerPlayer<ComputerPlayerMCTS> imple
return;
}
//try to ensure that there are at least 20 simulations per node at all times
//try to ensure that there are at least THINK_MIN_RATIO simulations per node at all times
private int calculateThinkTime(Game game, NextAction action) {
int thinkTime = 0;
int nodeSizeRatio = 0;
if (root.getNumChildren() > 0)
nodeSizeRatio = root.size() / root.getNumChildren();
nodeSizeRatio = root.getVisits() / root.getNumChildren();
logger.info("Ratio: " + nodeSizeRatio);
PhaseStep curStep = game.getStep().getType();
if (action == NextAction.SELECT_ATTACKERS || action == NextAction.SELECT_BLOCKERS) {
if (nodeSizeRatio < thinkTimeRatioThreshold) {
if (nodeSizeRatio < THINK_MIN_RATIO) {
thinkTime = maxThinkTime;
}
else if (nodeSizeRatio >= THINK_MAX_RATIO) {
thinkTime = 0;
}
else {
thinkTime = maxThinkTime / 2;
}
}
else if (game.getActivePlayerId().equals(playerId) && (curStep == PhaseStep.PRECOMBAT_MAIN || curStep == PhaseStep.POSTCOMBAT_MAIN)) {
if (nodeSizeRatio < thinkTimeRatioThreshold) {
else if (game.getActivePlayerId().equals(playerId) && (curStep == PhaseStep.PRECOMBAT_MAIN || curStep == PhaseStep.POSTCOMBAT_MAIN) && game.getStack().isEmpty()) {
if (nodeSizeRatio < THINK_MIN_RATIO) {
thinkTime = maxThinkTime;
}
else if (nodeSizeRatio >= THINK_MAX_RATIO) {
thinkTime = 0;
}
else {
thinkTime = maxThinkTime / 2;
}
}
else {
if (nodeSizeRatio < thinkTimeRatioThreshold) {
if (nodeSizeRatio < THINK_MIN_RATIO) {
thinkTime = maxThinkTime / 2;
}
else {

View file

@ -61,7 +61,6 @@ public class MCTSNode {
private MCTSNode parent;
private List<MCTSNode> children = new ArrayList<MCTSNode>();
private Ability action;
// private Combat combat;
private Game game;
private String stateValue;
private UUID playerId;
@ -88,7 +87,6 @@ public class MCTSNode {
this.game = game;
this.stateValue = game.getState().getValue(false, game);
this.parent = parent;
// this.combat = game.getCombat();
setPlayer();
nodeCount++;
}
@ -140,7 +138,6 @@ public class MCTSNode {
List<Ability> abilities = player.getPlayableOptions(game);
for (Ability ability: abilities) {
Game sim = game.copy();
// String simState = sim.getState().getValue(false, sim);
// logger.info("expand " + ability.toString());
MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId());
simPlayer.activateAbility((ActivatedAbility)ability, sim);
@ -154,7 +151,6 @@ public class MCTSNode {
UUID defenderId = game.getOpponents(player.getId()).iterator().next();
for (List<UUID> attack: attacks) {
Game sim = game.copy();
// String simState = sim.getState().getValue(false, sim);
MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId());
for (UUID attackerId: attack) {
simPlayer.declareAttacker(attackerId, defenderId, sim);
@ -168,7 +164,6 @@ public class MCTSNode {
List<List<List<UUID>>> blocks = player.getBlocks(game);
for (List<List<UUID>> block: blocks) {
Game sim = game.copy();
// String simState = sim.getState().getValue(false, sim);
MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId());
List<CombatGroup> groups = sim.getCombat().getGroups();
for (int i = 0; i < groups.size(); i++) {
@ -248,7 +243,10 @@ public class MCTSNode {
}
/**
 * Detaches this node from its parent.
 * <p>
 * The node is first removed from the parent's child list and only then is the
 * back-reference cleared. A node that has no parent is left unchanged.
 */
public void emancipate() {
    if (parent != null) {
        // Unlink from the parent before clearing the reference; nulling the
        // reference first would make this branch unreachable.
        parent.children.remove(this);
        parent = null;
    }
}
public Ability getAction() {
@ -275,6 +273,16 @@ public class MCTSNode {
return stateValue;
}
/**
 * Computes the ratio of wins to visits for this node.
 *
 * @return {@code wins} divided by {@code visits} as a floating-point value,
 *         or {@code -1.0} when {@code visits} is not greater than zero
 */
public double getWinRatio() {
    if (visits <= 0) {
        // no visits recorded: no ratio can be computed
        return -1.0;
    }
    return wins / (visits * 1.0);
}
/**
 * @return the current value of this node's {@code visits} counter
 */
public int getVisits() {
    return this.visits;
}
/**
* Copies game and replaces all players in copy with simulated players
* Shuffles each player's library so that there is no knowledge of its order
@ -289,26 +297,37 @@ public class MCTSNode {
Player origPlayer = game.getState().getPlayers().get(copyPlayer.getId()).copy();
SimulatedPlayerMCTS newPlayer = new SimulatedPlayerMCTS(copyPlayer.getId(), true);
newPlayer.restore(origPlayer);
if (!newPlayer.getId().equals(playerId)) {
int handSize = newPlayer.getHand().size();
newPlayer.getLibrary().addAll(newPlayer.getHand().getCards(sim), sim);
newPlayer.getHand().clear();
newPlayer.getLibrary().shuffle();
for (int i = 0; i < handSize; i++) {
Card card = newPlayer.getLibrary().removeFromTop(sim);
sim.setZone(card.getId(), Zone.HAND);
newPlayer.getHand().add(card);
}
}
else {
newPlayer.getLibrary().shuffle();
}
sim.getState().getPlayers().put(copyPlayer.getId(), newPlayer);
}
randomizePlayers(sim, playerId);
sim.setSimulation(true);
return sim;
}
/**
 * Hides information from the given game state.
 * <p>
 * For the player identified by {@code playerId} only the library is shuffled.
 * For every other player the hand is put into the library, the library is
 * shuffled, and the same number of cards is taken back from the top of the
 * library into the hand.
 *
 * @param game     the game whose players are randomized
 * @param playerId id of the player whose hand is left untouched
 */
protected void randomizePlayers(Game game, UUID playerId) {
    for (Player participant : game.getState().getPlayers().values()) {
        if (participant.getId().equals(playerId)) {
            // this player's hand stays as it is; only the library order changes
            participant.getLibrary().shuffle();
            continue;
        }

        // remember the hand size so the same number of cards is taken back
        int cardsToTake = participant.getHand().size();
        participant.getLibrary().addAll(participant.getHand().getCards(game), game);
        participant.getHand().clear();
        participant.getLibrary().shuffle();

        int taken = 0;
        while (taken < cardsToTake) {
            Card replacement = participant.getLibrary().removeFromTop(game);
            game.setZone(replacement.getId(), Zone.HAND);
            participant.getHand().add(replacement);
            taken++;
        }
    }
}
/**
 * @return {@code true} when this node's game reports that it is over
 */
public boolean isTerminal() {
    return this.game.isGameOver();
}

View file

@ -28,16 +28,13 @@
package mage.player.ai;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import mage.abilities.Ability;
import mage.abilities.SpellAbility;
import mage.abilities.common.PassAbility;
import mage.abilities.costs.mana.GenericManaCost;
import mage.game.Game;
import mage.game.combat.Combat;
import mage.game.permanent.Permanent;
import org.apache.log4j.Logger;