From 9209e4331064f2e258818cf2452d5635e9ea640c Mon Sep 17 00:00:00 2001 From: betasteward Date: Thu, 30 Apr 2015 11:33:22 -0400 Subject: [PATCH] Updates to MCTS AI --- .../mage/player/ai/ComputerPlayerMCTS.java | 77 +++++++++++++------ .../src/mage/player/ai/MCTSExecutor.java | 15 ++-- .../src/mage/player/ai/MCTSNode.java | 31 +++++--- 3 files changed, 80 insertions(+), 43 deletions(-) diff --git a/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/ComputerPlayerMCTS.java b/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/ComputerPlayerMCTS.java index 56959a416d..239bdf3e55 100644 --- a/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/ComputerPlayerMCTS.java +++ b/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/ComputerPlayerMCTS.java @@ -46,6 +46,8 @@ import java.util.List; import java.util.UUID; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.TimeUnit; /** * @@ -60,16 +62,14 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player { protected transient MCTSNode root; protected int maxThinkTime; - private static final transient Logger logger = Logger.getLogger(ComputerPlayerMCTS.class); - private transient ExecutorService pool; - private int cores; + private static final transient Logger logger = Logger.getLogger(ComputerPlayerMCTS.class); + private int poolSize; public ComputerPlayerMCTS(String name, RangeOfInfluence range, int skill) { super(name, range); human = false; maxThinkTime = (int) (skill * THINK_TIME_MULTIPLIER); - cores = Runtime.getRuntime().availableProcessors(); - pool = Executors.newFixedThreadPool(cores); + poolSize = Runtime.getRuntime().availableProcessors(); } protected ComputerPlayerMCTS(UUID id) { @@ -85,19 +85,26 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player { return new ComputerPlayerMCTS(this); } + protected String lastPhase = ""; @Override public boolean priority(Game game) { - if (game.getStep().getType() == PhaseStep.PRECOMBAT_MAIN) - logList("computer player " + name + " hand: ", new ArrayList(hand.getCards(game))); + if (game.getStep().getType() == PhaseStep.UPKEEP) { + if (!lastPhase.equals(game.getTurn().getValue(game.getTurnNum()))) { + logList(game.getTurn().getValue(game.getTurnNum()) + name + " hand: ", new ArrayList(hand.getCards(game))); + lastPhase = game.getTurn().getValue(game.getTurnNum()); + } + } game.getState().setPriorityPlayerId(playerId); game.firePriorityEvent(playerId); getNextAction(game, NextAction.PRIORITY); - Ability ability = root.getAction(); + Ability ability = root.getAction(); if (ability == null) logger.fatal("null ability"); activateAbility((ActivatedAbility)ability, game); if (ability instanceof PassAbility) return false; + logLife(game); + logger.info("choose action:" + root.getAction() + " success ratio: " + root.getWinRatio()); return true; } @@ -119,7 +126,6 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player { newRoot = root.getMatchingState(game.getState().getValue(false, game)); if (newRoot != null) { newRoot.emancipate(); - logger.info("choose action:" + newRoot.getAction() + " success ratio: " + newRoot.getWinRatio()); } else logger.info("unable to find matching state"); @@ -197,26 +203,36 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player { @Override public void selectAttackers(Game game, UUID attackingPlayerId) { + StringBuilder sb = new StringBuilder(); + sb.append(game.getTurn().getValue(game.getTurnNum())).append(" player ").append(name).append(" attacking with: "); getNextAction(game, NextAction.SELECT_ATTACKERS); Combat combat = root.getCombat(); UUID opponentId = game.getCombat().getDefenders().iterator().next(); for (UUID attackerId: combat.getAttackers()) { this.declareAttacker(attackerId, opponentId, game, false); + sb.append(game.getPermanent(attackerId).getName()).append(","); } + logger.info(sb.toString()); } @Override public void selectBlockers(Game game, UUID defendingPlayerId) { + StringBuilder sb = new StringBuilder(); + sb.append(game.getTurn().getValue(game.getTurnNum())).append(" player ").append(name).append(" blocking: "); getNextAction(game, NextAction.SELECT_BLOCKERS); Combat combat = root.getCombat(); List groups = game.getCombat().getGroups(); for (int i = 0; i < groups.size(); i++) { if (i < combat.getGroups().size()) { + sb.append(game.getPermanent(groups.get(i).getAttackers().get(0)).getName()).append(" with: "); for (UUID blockerId: combat.getGroups().get(i).getBlockers()) { this.declareBlocker(this.getId(), blockerId, groups.get(i).getAttackers().get(0), game); + sb.append(game.getPermanent(blockerId).getName()).append(","); } + sb.append("|"); } } + logger.info(sb.toString()); } // @Override @@ -254,17 +270,16 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player { // throw new UnsupportedOperationException("Not supported yet."); // } + protected long totalThinkTime = 0; + protected long totalSimulations = 0; protected void applyMCTS(final Game game, final NextAction action) { int thinkTime = calculateThinkTime(game, action); - - long startTime = System.nanoTime(); - long endTime = startTime + (thinkTime * 1000000000l); - logger.info("applyMCTS - Thinking for " + (endTime - startTime)/1000000000.0 + "s"); - + if (thinkTime > 0) { if (USE_MULTIPLE_THREADS) { - List tasks = new ArrayList(); - for (int i = 0; i < cores; i++) { + ExecutorService pool = Executors.newFixedThreadPool(poolSize); + List tasks = new ArrayList<>(); + for (int i = 0; i < poolSize; i++) { Game sim = createMCTSGame(game); MCTSPlayer player = (MCTSPlayer) sim.getPlayer(playerId); player.setNextAction(action); @@ -273,18 +288,28 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player { } try { - pool.invokeAll(tasks); - } catch (InterruptedException ex) { + pool.invokeAll(tasks, thinkTime, TimeUnit.SECONDS); + pool.awaitTermination(1, TimeUnit.SECONDS); + pool.shutdownNow(); + } catch (InterruptedException | RejectedExecutionException ex) { logger.warn("applyMCTS interrupted"); } - + + int simCount = 0; for (MCTSExecutor task: tasks) { + simCount += task.getSimCount(); root.merge(task.getRoot()); task.clear(); } tasks.clear(); + totalThinkTime += thinkTime; + totalSimulations += simCount; + logger.info("Player: " + name + " Simulated " + simCount + " games in " + thinkTime + " seconds - nodes in tree: " + root.size()); + logger.info("Total: Simulated " + totalSimulations + " games in " + totalThinkTime + " seconds - Average: " + totalSimulations/totalThinkTime); } else { + long startTime = System.nanoTime(); + long endTime = startTime + (thinkTime * 1000000000l); MCTSNode current; int simCount = 0; while (true) { @@ -316,10 +341,9 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player { } logger.info("Simulated " + simCount + " games - nodes in tree: " + root.size()); } - displayMemory(); +// displayMemory(); } -// root.print(1); } //try to ensure that there are at least THINK_MIN_RATIO simulations per node at all times @@ -328,7 +352,7 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player { int nodeSizeRatio = 0; if (root.getNumChildren() > 0) nodeSizeRatio = root.getVisits() / root.getNumChildren(); - logger.info("Ratio: " + nodeSizeRatio); +// logger.info("Ratio: " + nodeSizeRatio); PhaseStep curStep = game.getStep().getType(); if (action == NextAction.SELECT_ATTACKERS || action == NextAction.SELECT_BLOCKERS) { if (nodeSizeRatio < THINK_MIN_RATIO) { @@ -410,5 +434,14 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player { logger.info("Max heap size: " + heapMaxSize/mb + " Heap size: " + heapSize/mb + " Used: " + heapUsedSize/mb); } + + protected void logLife(Game game) { + StringBuilder sb = new StringBuilder(); + sb.append(game.getTurn().getValue(game.getTurnNum())); + for (Player player: game.getPlayers().values()) { + sb.append("[player ").append(player.getName()).append(":").append(player.getLife()).append("]"); + } + logger.info(sb.toString()); + } } diff --git a/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSExecutor.java b/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSExecutor.java index fbe716aa37..1612876d57 100644 --- a/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSExecutor.java +++ b/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSExecutor.java @@ -41,8 +41,9 @@ public class MCTSExecutor implements Callable { protected transient MCTSNode root; protected int thinkTime; protected UUID playerId; + protected int simCount; - private static final transient Logger logger = Logger.getLogger(ComputerPlayerMCTS.class); + private static final transient Logger logger = Logger.getLogger(ComputerPlayerMCTS.class); public MCTSExecutor(Game sim, UUID playerId, int thinkTime) { this.playerId = playerId; @@ -52,16 +53,11 @@ public class MCTSExecutor implements Callable { @Override public Boolean call() { - int simCount = 0; - long startTime = System.nanoTime(); - long endTime = startTime + (thinkTime * 1000000000l); + simCount = 0; MCTSNode current; while (true) { - long currentTime = System.nanoTime(); - if (currentTime > endTime) - break; current = root; // Selection @@ -92,8 +88,6 @@ public class MCTSExecutor implements Callable { // Backpropagation current.backpropagate(result); } - logger.info("Simulated " + simCount + " games - nodes in tree: " + root.size()); - return true; } public MCTSNode getRoot() { @@ -104,4 +98,7 @@ public class MCTSExecutor implements Callable { root = null; } + public int getSimCount() { + return simCount; + } } diff --git a/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSNode.java b/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSNode.java index bebd0914a7..28f28644ab 100644 --- a/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSNode.java +++ b/Mage.Server.Plugins/Mage.Player.AIMCTS/src/mage/player/ai/MCTSNode.java @@ -51,14 +51,14 @@ import org.apache.log4j.Logger; */ public class MCTSNode { - private static final double selectionCoefficient = 1.0; + private static final double selectionCoefficient = Math.sqrt(2.0); private static final double passRatioTolerance = 0.0; - private static final transient Logger logger = Logger.getLogger(MCTSNode.class); + private static final transient Logger logger = Logger.getLogger(MCTSNode.class); private int visits = 0; private int wins = 0; private MCTSNode parent; - private final List children = new ArrayList(); + private final List children = new ArrayList<>(); private Ability action; private Game game; private Combat combat; @@ -74,6 +74,7 @@ public class MCTSNode { this.terminal = game.gameOver(null); setPlayer(); nodeCount = 1; +// logger.info(this.stateValue); } protected MCTSNode(MCTSNode parent, Game game, Ability action) { @@ -84,6 +85,7 @@ public class MCTSNode { this.action = action; setPlayer(); nodeCount++; +// logger.info(this.stateValue); } protected MCTSNode(MCTSNode parent, Game game, Combat combat) { @@ -94,6 +96,7 @@ public class MCTSNode { this.parent = parent; setPlayer(); nodeCount++; +// logger.info(this.stateValue); } private void setPlayer() { @@ -356,11 +359,10 @@ public class MCTSNode { * performs a breadth first search for a matching game state * * @param state - the game state that we are looking for - * @param nextAction - the next action that will be performed * @return the matching state or null if no match is found */ public MCTSNode getMatchingState(String state) { - ArrayDeque queue = new ArrayDeque(); + ArrayDeque queue = new ArrayDeque<>(); queue.add(this); while (!queue.isEmpty()) { @@ -376,14 +378,15 @@ public class MCTSNode { public void merge(MCTSNode merge) { if (!stateValue.equals(merge.stateValue)) { - logger.info("mismatched merge states"); + logger.info("mismatched merge states at root"); return; } this.visits += merge.visits; this.wins += merge.wins; - - List mergeChildren = new ArrayList(); + int mismatchCount = 0; + + List mergeChildren = new ArrayList<>(); for (MCTSNode child: merge.children) { mergeChildren.add(child); } @@ -393,8 +396,9 @@ public class MCTSNode { if (mergeChild.action != null && child.action != null) { if (mergeChild.action.toString().equals(child.action.toString())) { if (!mergeChild.stateValue.equals(child.stateValue)) { - logger.info("mismatched merge states"); - mergeChildren.remove(mergeChild); + mismatchCount++; +// logger.info("mismatched merge states"); +// mergeChildren.remove(mergeChild); } else { child.merge(mergeChild); @@ -406,8 +410,9 @@ public class MCTSNode { else { if (mergeChild.combat.getValue().equals(child.combat.getValue())) { if (!mergeChild.stateValue.equals(child.stateValue)) { - logger.info("mismatched merge states"); - mergeChildren.remove(mergeChild); + mismatchCount++; +// logger.info("mismatched merge states"); +// mergeChildren.remove(mergeChild); } else { child.merge(mergeChild); @@ -424,6 +429,8 @@ public class MCTSNode { children.add(child); } } +// if (mismatchCount > 0) +// logger.info("mismatched merge states: " + mismatchCount); } // public void print(int depth) {