more MCTS improvements

This commit is contained in:
betasteward 2015-05-04 15:41:37 -04:00
parent 6403fff12b
commit c5e216ddbf
4 changed files with 199 additions and 12 deletions

View file

@ -92,6 +92,11 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player {
if (!lastPhase.equals(game.getTurn().getValue(game.getTurnNum()))) { if (!lastPhase.equals(game.getTurn().getValue(game.getTurnNum()))) {
logList(game.getTurn().getValue(game.getTurnNum()) + name + " hand: ", new ArrayList(hand.getCards(game))); logList(game.getTurn().getValue(game.getTurnNum()) + name + " hand: ", new ArrayList(hand.getCards(game)));
lastPhase = game.getTurn().getValue(game.getTurnNum()); lastPhase = game.getTurn().getValue(game.getTurnNum());
if (MCTSNode.USE_ACTION_CACHE) {
int count = MCTSNode.cleanupCache(game.getTurnNum());
if (count > 0)
logger.info("Removed " + count + " cache entries");
}
} }
} }
game.getState().setPriorityPlayerId(playerId); game.getState().setPriorityPlayerId(playerId);
@ -113,7 +118,7 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player {
Game sim = createMCTSGame(game); Game sim = createMCTSGame(game);
MCTSPlayer player = (MCTSPlayer) sim.getPlayer(playerId); MCTSPlayer player = (MCTSPlayer) sim.getPlayer(playerId);
player.setNextAction(action); player.setNextAction(action);
root = new MCTSNode(sim); root = new MCTSNode(playerId, sim);
} }
applyMCTS(game, action); applyMCTS(game, action);
root = root.bestChild(); root = root.bestChild();
@ -123,7 +128,7 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player {
protected void getNextAction(Game game, NextAction nextAction) { protected void getNextAction(Game game, NextAction nextAction) {
if (root != null) { if (root != null) {
MCTSNode newRoot; MCTSNode newRoot;
newRoot = root.getMatchingState(game.getState().getValue(false, game)); newRoot = root.getMatchingState(game.getState().getValue(game, playerId));
if (newRoot != null) { if (newRoot != null) {
newRoot.emancipate(); newRoot.emancipate();
} }
@ -213,6 +218,7 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player {
sb.append(game.getPermanent(attackerId).getName()).append(","); sb.append(game.getPermanent(attackerId).getName()).append(",");
} }
logger.info(sb.toString()); logger.info(sb.toString());
MCTSNode.logHitMiss();
} }
@Override @Override
@ -233,6 +239,7 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player {
} }
} }
logger.info(sb.toString()); logger.info(sb.toString());
MCTSNode.logHitMiss();
} }
// @Override // @Override
@ -273,6 +280,7 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player {
protected long totalThinkTime = 0; protected long totalThinkTime = 0;
protected long totalSimulations = 0; protected long totalSimulations = 0;
protected void applyMCTS(final Game game, final NextAction action) { protected void applyMCTS(final Game game, final NextAction action) {
int thinkTime = calculateThinkTime(game, action); int thinkTime = calculateThinkTime(game, action);
if (thinkTime > 0) { if (thinkTime > 0) {
@ -306,6 +314,7 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player {
totalSimulations += simCount; totalSimulations += simCount;
logger.info("Player: " + name + " Simulated " + simCount + " games in " + thinkTime + " seconds - nodes in tree: " + root.size()); logger.info("Player: " + name + " Simulated " + simCount + " games in " + thinkTime + " seconds - nodes in tree: " + root.size());
logger.info("Total: Simulated " + totalSimulations + " games in " + totalThinkTime + " seconds - Average: " + totalSimulations/totalThinkTime); logger.info("Total: Simulated " + totalSimulations + " games in " + totalThinkTime + " seconds - Average: " + totalSimulations/totalThinkTime);
MCTSNode.logHitMiss();
} }
else { else {
long startTime = System.nanoTime(); long startTime = System.nanoTime();

View file

@ -48,7 +48,7 @@ public class MCTSExecutor implements Callable<Boolean> {
public MCTSExecutor(Game sim, UUID playerId, int thinkTime) { public MCTSExecutor(Game sim, UUID playerId, int thinkTime) {
this.playerId = playerId; this.playerId = playerId;
this.thinkTime = thinkTime; this.thinkTime = thinkTime;
root = new MCTSNode(sim); root = new MCTSNode(playerId, sim);
} }
@Override @Override

View file

@ -29,8 +29,11 @@ package mage.player.ai;
import java.util.ArrayDeque; import java.util.ArrayDeque;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Set;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import mage.constants.PhaseStep; import mage.constants.PhaseStep;
import mage.constants.Zone; import mage.constants.Zone;
import mage.abilities.Ability; import mage.abilities.Ability;
@ -51,6 +54,7 @@ import org.apache.log4j.Logger;
*/ */
public class MCTSNode { public class MCTSNode {
public static final boolean USE_ACTION_CACHE = false;
private static final double selectionCoefficient = Math.sqrt(2.0); private static final double selectionCoefficient = Math.sqrt(2.0);
private static final double passRatioTolerance = 0.0; private static final double passRatioTolerance = 0.0;
private static final transient Logger logger = Logger.getLogger(MCTSNode.class); private static final transient Logger logger = Logger.getLogger(MCTSNode.class);
@ -63,14 +67,18 @@ public class MCTSNode {
private Game game; private Game game;
private Combat combat; private Combat combat;
private final String stateValue; private final String stateValue;
private final String fullStateValue;
private UUID playerId; private UUID playerId;
private boolean terminal = false; private boolean terminal = false;
private UUID targetPlayer;
private static int nodeCount; private static int nodeCount;
public MCTSNode(Game game) { public MCTSNode(UUID targetPlayer, Game game) {
this.targetPlayer = targetPlayer;
this.game = game; this.game = game;
this.stateValue = game.getState().getValue(false, game); this.stateValue = game.getState().getValue(game, targetPlayer);
this.fullStateValue = game.getState().getValue(true, game);
this.terminal = game.gameOver(null); this.terminal = game.gameOver(null);
setPlayer(); setPlayer();
nodeCount = 1; nodeCount = 1;
@ -78,8 +86,10 @@ public class MCTSNode {
} }
protected MCTSNode(MCTSNode parent, Game game, Ability action) { protected MCTSNode(MCTSNode parent, Game game, Ability action) {
this.targetPlayer = parent.targetPlayer;
this.game = game; this.game = game;
this.stateValue = game.getState().getValue(false, game); this.stateValue = game.getState().getValue(game, targetPlayer);
this.fullStateValue = game.getState().getValue(true, game);
this.terminal = game.gameOver(null); this.terminal = game.gameOver(null);
this.parent = parent; this.parent = parent;
this.action = action; this.action = action;
@ -89,9 +99,11 @@ public class MCTSNode {
} }
protected MCTSNode(MCTSNode parent, Game game, Combat combat) { protected MCTSNode(MCTSNode parent, Game game, Combat combat) {
this.targetPlayer = parent.targetPlayer;
this.game = game; this.game = game;
this.combat = combat; this.combat = combat;
this.stateValue = game.getState().getValue(false, game); this.stateValue = game.getState().getValue(game, targetPlayer);
this.fullStateValue = game.getState().getValue(true, game);
this.terminal = game.gameOver(null); this.terminal = game.gameOver(null);
this.parent = parent; this.parent = parent;
setPlayer(); setPlayer();
@ -144,7 +156,11 @@ public class MCTSNode {
switch (player.getNextAction()) { switch (player.getNextAction()) {
case PRIORITY: case PRIORITY:
// logger.info("Priority for player:" + player.getName() + " turn: " + game.getTurnNum() + " phase: " + game.getPhase().getType() + " step: " + game.getStep().getType()); // logger.info("Priority for player:" + player.getName() + " turn: " + game.getTurnNum() + " phase: " + game.getPhase().getType() + " step: " + game.getStep().getType());
List<Ability> abilities = player.getPlayableOptions(game); List<Ability> abilities;
if (!USE_ACTION_CACHE)
abilities = player.getPlayableOptions(game);
else
abilities = getPlayables(player, fullStateValue, game);
for (Ability ability: abilities) { for (Ability ability: abilities) {
Game sim = game.copy(); Game sim = game.copy();
// logger.info("expand " + ability.toString()); // logger.info("expand " + ability.toString());
@ -156,7 +172,11 @@ public class MCTSNode {
break; break;
case SELECT_ATTACKERS: case SELECT_ATTACKERS:
// logger.info("Select attackers:" + player.getName()); // logger.info("Select attackers:" + player.getName());
List<List<UUID>> attacks = player.getAttacks(game); List<List<UUID>> attacks;
if (!USE_ACTION_CACHE)
attacks = player.getAttacks(game);
else
attacks = getAttacks(player, fullStateValue, game);
UUID defenderId = game.getOpponents(player.getId()).iterator().next(); UUID defenderId = game.getOpponents(player.getId()).iterator().next();
for (List<UUID> attack: attacks) { for (List<UUID> attack: attacks) {
Game sim = game.copy(); Game sim = game.copy();
@ -170,7 +190,11 @@ public class MCTSNode {
break; break;
case SELECT_BLOCKERS: case SELECT_BLOCKERS:
// logger.info("Select blockers:" + player.getName()); // logger.info("Select blockers:" + player.getName());
List<List<List<UUID>>> blocks = player.getBlocks(game); List<List<List<UUID>>> blocks;
if (!USE_ACTION_CACHE)
blocks = player.getBlocks(game);
else
blocks = getBlocks(player, fullStateValue, game);
for (List<List<UUID>> block: blocks) { for (List<List<UUID>> block: blocks) {
Game sim = game.copy(); Game sim = game.copy();
MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId()); MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId());
@ -454,4 +478,99 @@ public class MCTSNode {
return num; return num;
} }
private static final ConcurrentHashMap<String, List<Ability>> playablesCache = new ConcurrentHashMap<>();
private static final ConcurrentHashMap<String, List<List<UUID>>> attacksCache = new ConcurrentHashMap<>();
private static final ConcurrentHashMap<String, List<List<List<UUID>>>> blocksCache = new ConcurrentHashMap<>();
private static long playablesHit = 0;
private static long playablesMiss = 0;
private static long attacksHit = 0;
private static long attacksMiss = 0;
private static long blocksHit = 0;
private static long blocksMiss = 0;
private static List<Ability> getPlayables(MCTSPlayer player, String state, Game game) {
if (playablesCache.containsKey(state)) {
playablesHit++;
return playablesCache.get(state);
}
else {
playablesMiss++;
List<Ability> abilities = player.getPlayableOptions(game);
playablesCache.put(state, abilities);
return abilities;
}
}
private static List<List<UUID>> getAttacks(MCTSPlayer player, String state, Game game) {
if (attacksCache.containsKey(state)) {
attacksHit++;
return attacksCache.get(state);
}
else {
attacksMiss++;
List<List<UUID>> attacks = player.getAttacks(game);
attacksCache.put(state, attacks);
return attacks;
}
}
private static List<List<List<UUID>>> getBlocks(MCTSPlayer player, String state, Game game) {
if (blocksCache.containsKey(state)) {
blocksHit++;
return blocksCache.get(state);
}
else {
blocksMiss++;
List<List<List<UUID>>> blocks = player.getBlocks(game);
blocksCache.put(state, blocks);
return blocks;
}
}
public static int cleanupCache(int turnNum) {
Set<String> playablesKeys = playablesCache.keySet();
Iterator<String> playablesIterator = playablesKeys.iterator();
int count = 0;
while(playablesIterator.hasNext()) {
String next = playablesIterator.next();
int cacheTurn = Integer.valueOf(next.split(":", 2)[0].substring(1));
if (cacheTurn < turnNum) {
playablesIterator.remove();
count++;
}
}
Set<String> attacksKeys = attacksCache.keySet();
Iterator<String> attacksIterator = attacksKeys.iterator();
while(attacksIterator.hasNext()) {
int cacheTurn = Integer.valueOf(attacksIterator.next().split(":", 2)[0].substring(1));
if (cacheTurn < turnNum) {
attacksIterator.remove();
count++;
}
}
Set<String> blocksKeys = blocksCache.keySet();
Iterator<String> blocksIterator = blocksKeys.iterator();
while(blocksIterator.hasNext()) {
int cacheTurn = Integer.valueOf(blocksIterator.next().split(":", 2)[0].substring(1));
if (cacheTurn < turnNum) {
blocksIterator.remove();
count++;
}
}
return count;
}
public static void logHitMiss() {
if (USE_ACTION_CACHE) {
StringBuilder sb = new StringBuilder();
sb.append("Playables Cache -- Hits: ").append(playablesHit).append(" Misses: ").append(playablesMiss).append("\n");
sb.append("Attacks Cache -- Hits: ").append(attacksHit).append(" Misses: ").append(attacksMiss).append("\n");
sb.append("Blocks Cache -- Hits: ").append(blocksHit).append(" Misses: ").append(blocksMiss).append("\n");
logger.info(sb.toString());
}
}
} }

View file

@ -226,7 +226,66 @@ public class GameState implements Serializable, Copyable<GameState> {
for (Player player: players.values()) { for (Player player: players.values()) {
sb.append("player").append(player.isPassed()).append(player.getLife()).append("hand"); sb.append("player").append(player.isPassed()).append(player.getLife()).append("hand");
if (useHidden && priorityPlayerId == player.getId()) { if (useHidden) {
sb.append(player.getHand().getValue(game));
}
else {
sb.append(player.getHand().size());
}
sb.append("library").append(player.getLibrary().size());
sb.append("graveyard");
sb.append(player.getGraveyard().getValue(game));
}
sb.append("permanents");
List<String> perms = new ArrayList<>();
for (Permanent permanent: battlefield.getAllPermanents()) {
perms.add(permanent.getValue());
}
Collections.sort(perms);
sb.append(perms);
sb.append("spells");
for (StackObject spell: stack) {
sb.append(spell.getControllerId()).append(spell.getName());
sb.append(spell.getStackAbility().toString());
for (Mode mode: spell.getStackAbility().getModes().values()) {
if (!mode.getTargets().isEmpty()) {
sb.append("targets");
for (Target target: mode.getTargets()) {
sb.append(target.getTargets());
}
}
if (!mode.getChoices().isEmpty()) {
sb.append("choices");
for (Choice choice: mode.getChoices()) {
sb.append(choice.getChoice());
}
}
}
}
for (ExileZone zone: exile.getExileZones()) {
sb.append("exile").append(zone.getName()).append(zone.getValue(game));
}
sb.append("combat");
for (CombatGroup group: combat.getGroups()) {
sb.append(group.getDefenderId()).append(group.getAttackers()).append(group.getBlockers());
}
return sb.toString();
}
public String getValue(Game game, UUID playerId) {
StringBuilder sb = threadLocalBuilder.get();
sb.append(turn.getValue(turnNum));
sb.append(activePlayerId).append(priorityPlayerId);
for (Player player: players.values()) {
sb.append("player").append(player.isPassed()).append(player.getLife()).append("hand");
if (playerId == player.getId()) {
sb.append(player.getHand().getValue(game)); sb.append(player.getHand().getValue(game));
} }
else { else {
@ -856,4 +915,4 @@ public class GameState implements Serializable, Copyable<GameState> {
} }
return copiedCard; return copiedCard;
} }
} }