From c5e216ddbf2e3871f0214e74b856f0804f776c08 Mon Sep 17 00:00:00 2001 From: betasteward Date: Mon, 4 May 2015 15:41:37 -0400 Subject: [PATCH] more MCTS improvements --- .../mage/player/ai/ComputerPlayerMCTS.java | 13 +- .../src/mage/player/ai/MCTSExecutor.java | 2 +- .../src/mage/player/ai/MCTSNode.java | 133 +++++++++++++++++- Mage/src/mage/game/GameState.java | 63 ++++++++- 4 files changed, 199 insertions(+), 12 deletions(-) diff --git a/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/ComputerPlayerMCTS.java b/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/ComputerPlayerMCTS.java index 239bdf3e55..c620bb6840 100644 --- a/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/ComputerPlayerMCTS.java +++ b/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/ComputerPlayerMCTS.java @@ -92,6 +92,11 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player { if (!lastPhase.equals(game.getTurn().getValue(game.getTurnNum()))) { logList(game.getTurn().getValue(game.getTurnNum()) + name + " hand: ", new ArrayList(hand.getCards(game))); lastPhase = game.getTurn().getValue(game.getTurnNum()); + if (MCTSNode.USE_ACTION_CACHE) { + int count = MCTSNode.cleanupCache(game.getTurnNum()); + if (count > 0) + logger.info("Removed " + count + " cache entries"); + } } } game.getState().setPriorityPlayerId(playerId); @@ -113,7 +118,7 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player { Game sim = createMCTSGame(game); MCTSPlayer player = (MCTSPlayer) sim.getPlayer(playerId); player.setNextAction(action); - root = new MCTSNode(sim); + root = new MCTSNode(playerId, sim); } applyMCTS(game, action); root = root.bestChild(); @@ -123,7 +128,7 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player { protected void getNextAction(Game game, NextAction nextAction) { if (root != null) { MCTSNode newRoot; - newRoot = root.getMatchingState(game.getState().getValue(false, game)); + newRoot = 
root.getMatchingState(game.getState().getValue(game, playerId)); if (newRoot != null) { newRoot.emancipate(); } @@ -213,6 +218,7 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player { sb.append(game.getPermanent(attackerId).getName()).append(","); } logger.info(sb.toString()); + MCTSNode.logHitMiss(); } @Override @@ -233,6 +239,7 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player { } } logger.info(sb.toString()); + MCTSNode.logHitMiss(); } // @Override @@ -273,6 +280,7 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player { protected long totalThinkTime = 0; protected long totalSimulations = 0; protected void applyMCTS(final Game game, final NextAction action) { + int thinkTime = calculateThinkTime(game, action); if (thinkTime > 0) { @@ -306,6 +314,7 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player { totalSimulations += simCount; logger.info("Player: " + name + " Simulated " + simCount + " games in " + thinkTime + " seconds - nodes in tree: " + root.size()); logger.info("Total: Simulated " + totalSimulations + " games in " + totalThinkTime + " seconds - Average: " + totalSimulations/totalThinkTime); + MCTSNode.logHitMiss(); } else { long startTime = System.nanoTime(); diff --git a/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSExecutor.java b/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSExecutor.java index 1612876d57..1e12abdcc7 100644 --- a/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSExecutor.java +++ b/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSExecutor.java @@ -48,7 +48,7 @@ public class MCTSExecutor implements Callable { public MCTSExecutor(Game sim, UUID playerId, int thinkTime) { this.playerId = playerId; this.thinkTime = thinkTime; - root = new MCTSNode(sim); + root = new MCTSNode(playerId, sim); } @Override diff --git a/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSNode.java 
b/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSNode.java index 28f28644ab..eacf86be8b 100644 --- a/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSNode.java +++ b/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSNode.java @@ -29,8 +29,11 @@ package mage.player.ai; import java.util.ArrayDeque; import java.util.ArrayList; +import java.util.Iterator; import java.util.List; +import java.util.Set; import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; import mage.constants.PhaseStep; import mage.constants.Zone; import mage.abilities.Ability; @@ -51,6 +54,7 @@ import org.apache.log4j.Logger; */ public class MCTSNode { + public static final boolean USE_ACTION_CACHE = false; private static final double selectionCoefficient = Math.sqrt(2.0); private static final double passRatioTolerance = 0.0; private static final transient Logger logger = Logger.getLogger(MCTSNode.class); @@ -63,14 +67,18 @@ public class MCTSNode { private Game game; private Combat combat; private final String stateValue; + private final String fullStateValue; private UUID playerId; private boolean terminal = false; + private UUID targetPlayer; private static int nodeCount; - public MCTSNode(Game game) { + public MCTSNode(UUID targetPlayer, Game game) { + this.targetPlayer = targetPlayer; this.game = game; - this.stateValue = game.getState().getValue(false, game); + this.stateValue = game.getState().getValue(game, targetPlayer); + this.fullStateValue = game.getState().getValue(true, game); this.terminal = game.gameOver(null); setPlayer(); nodeCount = 1; @@ -78,8 +86,10 @@ public class MCTSNode { } protected MCTSNode(MCTSNode parent, Game game, Ability action) { + this.targetPlayer = parent.targetPlayer; this.game = game; - this.stateValue = game.getState().getValue(false, game); + this.stateValue = game.getState().getValue(game, targetPlayer); + this.fullStateValue = game.getState().getValue(true, game); this.terminal = game.gameOver(null); 
this.parent = parent; this.action = action; @@ -89,9 +99,11 @@ public class MCTSNode { } protected MCTSNode(MCTSNode parent, Game game, Combat combat) { + this.targetPlayer = parent.targetPlayer; this.game = game; this.combat = combat; - this.stateValue = game.getState().getValue(false, game); + this.stateValue = game.getState().getValue(game, targetPlayer); + this.fullStateValue = game.getState().getValue(true, game); this.terminal = game.gameOver(null); this.parent = parent; setPlayer(); @@ -144,7 +156,11 @@ public class MCTSNode { switch (player.getNextAction()) { case PRIORITY: // logger.info("Priority for player:" + player.getName() + " turn: " + game.getTurnNum() + " phase: " + game.getPhase().getType() + " step: " + game.getStep().getType()); - List abilities = player.getPlayableOptions(game); + List abilities; + if (!USE_ACTION_CACHE) + abilities = player.getPlayableOptions(game); + else + abilities = getPlayables(player, fullStateValue, game); for (Ability ability: abilities) { Game sim = game.copy(); // logger.info("expand " + ability.toString()); @@ -156,7 +172,11 @@ public class MCTSNode { break; case SELECT_ATTACKERS: // logger.info("Select attackers:" + player.getName()); - List> attacks = player.getAttacks(game); + List> attacks; + if (!USE_ACTION_CACHE) + attacks = player.getAttacks(game); + else + attacks = getAttacks(player, fullStateValue, game); UUID defenderId = game.getOpponents(player.getId()).iterator().next(); for (List attack: attacks) { Game sim = game.copy(); @@ -170,7 +190,11 @@ public class MCTSNode { break; case SELECT_BLOCKERS: // logger.info("Select blockers:" + player.getName()); - List>> blocks = player.getBlocks(game); + List>> blocks; + if (!USE_ACTION_CACHE) + blocks = player.getBlocks(game); + else + blocks = getBlocks(player, fullStateValue, game); for (List> block: blocks) { Game sim = game.copy(); MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId()); @@ -454,4 +478,99 @@ public class MCTSNode { return num; } + 
private static final ConcurrentHashMap<String, List<Ability>> playablesCache = new ConcurrentHashMap<>();
+    private static final ConcurrentHashMap<String, List<List<UUID>>> attacksCache = new ConcurrentHashMap<>();
+    private static final ConcurrentHashMap<String, List<List<List<UUID>>>> blocksCache = new ConcurrentHashMap<>();
+
+    // Diagnostic counters only: plain longs, so totals may be approximate if updated from several threads.
+    private static long playablesHit = 0;
+    private static long playablesMiss = 0;
+    private static long attacksHit = 0;
+    private static long attacksMiss = 0;
+    private static long blocksHit = 0;
+    private static long blocksMiss = 0;
+
+    /**
+     * Returns the playable options cached for {@code state}, computing and storing them on a miss.
+     * A single get() replaces containsKey()+get(), which could return null if cleanupCache() ran in between.
+     */
+    private static List<Ability> getPlayables(MCTSPlayer player, String state, Game game) {
+        List<Ability> abilities = playablesCache.get(state);
+        if (abilities != null) {
+            playablesHit++;
+            return abilities;
+        }
+        playablesMiss++;
+        abilities = player.getPlayableOptions(game);
+        playablesCache.put(state, abilities);
+        return abilities;
+    }
+
+    /** Returns the attack options cached for {@code state}; on a miss asks {@code player} and caches the result. */
+    private static List<List<UUID>> getAttacks(MCTSPlayer player, String state, Game game) {
+        List<List<UUID>> attacks = attacksCache.get(state);
+        if (attacks != null) {
+            attacksHit++;
+            return attacks;
+        }
+        attacksMiss++;
+        attacks = player.getAttacks(game);
+        attacksCache.put(state, attacks);
+        return attacks;
+    }
+
+    /** Returns the block options cached for {@code state}; on a miss asks {@code player} and caches the result. */
+    private static List<List<List<UUID>>> getBlocks(MCTSPlayer player, String state, Game game) {
+        List<List<List<UUID>>> blocks = blocksCache.get(state);
+        if (blocks != null) {
+            blocksHit++;
+            return blocks;
+        }
+        blocksMiss++;
+        blocks = player.getBlocks(game);
+        blocksCache.put(state, blocks);
+        return blocks;
+    }
+
+    /**
+     * Evicts, from all three action caches, every entry whose key encodes a turn earlier than {@code turnNum}.
+     * @param turnNum turn number to keep; entries keyed with a smaller turn are removed
+     * @return total number of entries removed
+     */
+    public static int cleanupCache(int turnNum) {
+        int count = evictOlderThan(playablesCache.keySet(), turnNum);
+        count += evictOlderThan(attacksCache.keySet(), turnNum);
+        count += evictOlderThan(blocksCache.keySet(), turnNum);
+        return count;
+    }
+
+    /**
+     * Removes from {@code keys} (a live key-set view) each key whose turn number is below {@code turnNum}.
+     * A key is read as: one leading character, then the turn number, then ':'.
+     * parseInt is used because the value is only compared as a primitive int.
+     * @return number of keys removed
+     */
+    private static int evictOlderThan(Set<String> keys, int turnNum) {
+        Iterator<String> it = keys.iterator();
+        int count = 0;
+        while (it.hasNext()) {
+            int cacheTurn = Integer.parseInt(it.next().split(":", 2)[0].substring(1));
+            if (cacheTurn < turnNum) {
+                it.remove();
+                count++;
+            }
+        }
+        return count;
+    }
+
+    /** Logs hit/miss totals of the three action caches; does nothing when USE_ACTION_CACHE is false. */
+    public static void logHitMiss() {
+        if (USE_ACTION_CACHE) {
+            StringBuilder sb = new StringBuilder();
+            sb.append("Playables Cache -- Hits: ").append(playablesHit).append(" Misses: ").append(playablesMiss).append("\n");
+            sb.append("Attacks Cache -- Hits: ").append(attacksHit).append(" Misses: ").append(attacksMiss).append("\n");
+            sb.append("Blocks Cache -- Hits: ").append(blocksHit).append(" Misses: ").append(blocksMiss).append("\n");
+            logger.info(sb.toString());
+        }
+    }
 }
diff --git a/Mage/src/mage/game/GameState.java b/Mage/src/mage/game/GameState.java
index c7684a38f2..2283989944 100644
--- a/Mage/src/mage/game/GameState.java
+++ b/Mage/src/mage/game/GameState.java
@@ -226,7 +226,66 @@ public class GameState implements Serializable, Copyable<GameState> {
 
         for (Player player: players.values()) {
             sb.append("player").append(player.isPassed()).append(player.getLife()).append("hand");
-            if (useHidden && priorityPlayerId == player.getId()) {
+            if (useHidden) {
+                sb.append(player.getHand().getValue(game));
+            }
+            else {
+                sb.append(player.getHand().size());
+            }
+            sb.append("library").append(player.getLibrary().size());
+            sb.append("graveyard");
+            sb.append(player.getGraveyard().getValue(game));
+        }
+
+        sb.append("permanents");
+        List<String> perms = new ArrayList<>();
+        for (Permanent permanent: battlefield.getAllPermanents()) {
+            perms.add(permanent.getValue());
+        }
+        Collections.sort(perms);
+        sb.append(perms);
+
+        sb.append("spells");
+        for (StackObject spell: stack) {
+            sb.append(spell.getControllerId()).append(spell.getName());
+            sb.append(spell.getStackAbility().toString());
+            for (Mode mode: spell.getStackAbility().getModes().values()) {
+                if (!mode.getTargets().isEmpty()) {
+                    sb.append("targets");
+                    for (Target target: mode.getTargets()) {
+                        sb.append(target.getTargets());
+                    }
+                }
+                if (!mode.getChoices().isEmpty()) {
+                    sb.append("choices");
+                    for (Choice choice: mode.getChoices()) {
+                        sb.append(choice.getChoice());
+                    }
+                }
+            }
+        }
+
+        for (ExileZone zone: exile.getExileZones()) {
+            sb.append("exile").append(zone.getName()).append(zone.getValue(game));
+        }
+
+        sb.append("combat");
+        for (CombatGroup group: combat.getGroups()) {
+            sb.append(group.getDefenderId()).append(group.getAttackers()).append(group.getBlockers());
+        }
+
+        return sb.toString();
+    }
+
+    /** Builds the state string seen by {@code playerId}: only that player's hand is appended in full. */
+    public String getValue(Game game, UUID playerId) {
+        StringBuilder sb = threadLocalBuilder.get();
+        sb.append(turn.getValue(turnNum));
+        sb.append(activePlayerId).append(priorityPlayerId);
+        // equals(), not ==: == on UUID compares object references, so equal ids held in distinct instances would not match
+        for (Player player: players.values()) {
+            sb.append("player").append(player.isPassed()).append(player.getLife()).append("hand");
+            if (player.getId().equals(playerId)) {
                 sb.append(player.getHand().getValue(game));
             }
             else {
@@ -856,4 +915,4 @@ public class GameState implements Serializable, Copyable<GameState> {
         }
         return copiedCard;
     }
-}
+ }