mirror of
https://github.com/correl/mage.git
synced 2024-11-25 03:00:11 +00:00
more MCTS improvements
This commit is contained in:
parent
6403fff12b
commit
c5e216ddbf
4 changed files with 199 additions and 12 deletions
|
@ -92,6 +92,11 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player {
|
||||||
if (!lastPhase.equals(game.getTurn().getValue(game.getTurnNum()))) {
|
if (!lastPhase.equals(game.getTurn().getValue(game.getTurnNum()))) {
|
||||||
logList(game.getTurn().getValue(game.getTurnNum()) + name + " hand: ", new ArrayList(hand.getCards(game)));
|
logList(game.getTurn().getValue(game.getTurnNum()) + name + " hand: ", new ArrayList(hand.getCards(game)));
|
||||||
lastPhase = game.getTurn().getValue(game.getTurnNum());
|
lastPhase = game.getTurn().getValue(game.getTurnNum());
|
||||||
|
if (MCTSNode.USE_ACTION_CACHE) {
|
||||||
|
int count = MCTSNode.cleanupCache(game.getTurnNum());
|
||||||
|
if (count > 0)
|
||||||
|
logger.info("Removed " + count + " cache entries");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
game.getState().setPriorityPlayerId(playerId);
|
game.getState().setPriorityPlayerId(playerId);
|
||||||
|
@ -113,7 +118,7 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player {
|
||||||
Game sim = createMCTSGame(game);
|
Game sim = createMCTSGame(game);
|
||||||
MCTSPlayer player = (MCTSPlayer) sim.getPlayer(playerId);
|
MCTSPlayer player = (MCTSPlayer) sim.getPlayer(playerId);
|
||||||
player.setNextAction(action);
|
player.setNextAction(action);
|
||||||
root = new MCTSNode(sim);
|
root = new MCTSNode(playerId, sim);
|
||||||
}
|
}
|
||||||
applyMCTS(game, action);
|
applyMCTS(game, action);
|
||||||
root = root.bestChild();
|
root = root.bestChild();
|
||||||
|
@ -123,7 +128,7 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player {
|
||||||
protected void getNextAction(Game game, NextAction nextAction) {
|
protected void getNextAction(Game game, NextAction nextAction) {
|
||||||
if (root != null) {
|
if (root != null) {
|
||||||
MCTSNode newRoot;
|
MCTSNode newRoot;
|
||||||
newRoot = root.getMatchingState(game.getState().getValue(false, game));
|
newRoot = root.getMatchingState(game.getState().getValue(game, playerId));
|
||||||
if (newRoot != null) {
|
if (newRoot != null) {
|
||||||
newRoot.emancipate();
|
newRoot.emancipate();
|
||||||
}
|
}
|
||||||
|
@ -213,6 +218,7 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player {
|
||||||
sb.append(game.getPermanent(attackerId).getName()).append(",");
|
sb.append(game.getPermanent(attackerId).getName()).append(",");
|
||||||
}
|
}
|
||||||
logger.info(sb.toString());
|
logger.info(sb.toString());
|
||||||
|
MCTSNode.logHitMiss();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -233,6 +239,7 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
logger.info(sb.toString());
|
logger.info(sb.toString());
|
||||||
|
MCTSNode.logHitMiss();
|
||||||
}
|
}
|
||||||
|
|
||||||
// @Override
|
// @Override
|
||||||
|
@ -273,6 +280,7 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player {
|
||||||
protected long totalThinkTime = 0;
|
protected long totalThinkTime = 0;
|
||||||
protected long totalSimulations = 0;
|
protected long totalSimulations = 0;
|
||||||
protected void applyMCTS(final Game game, final NextAction action) {
|
protected void applyMCTS(final Game game, final NextAction action) {
|
||||||
|
|
||||||
int thinkTime = calculateThinkTime(game, action);
|
int thinkTime = calculateThinkTime(game, action);
|
||||||
|
|
||||||
if (thinkTime > 0) {
|
if (thinkTime > 0) {
|
||||||
|
@ -306,6 +314,7 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player {
|
||||||
totalSimulations += simCount;
|
totalSimulations += simCount;
|
||||||
logger.info("Player: " + name + " Simulated " + simCount + " games in " + thinkTime + " seconds - nodes in tree: " + root.size());
|
logger.info("Player: " + name + " Simulated " + simCount + " games in " + thinkTime + " seconds - nodes in tree: " + root.size());
|
||||||
logger.info("Total: Simulated " + totalSimulations + " games in " + totalThinkTime + " seconds - Average: " + totalSimulations/totalThinkTime);
|
logger.info("Total: Simulated " + totalSimulations + " games in " + totalThinkTime + " seconds - Average: " + totalSimulations/totalThinkTime);
|
||||||
|
MCTSNode.logHitMiss();
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
long startTime = System.nanoTime();
|
long startTime = System.nanoTime();
|
||||||
|
|
|
@ -48,7 +48,7 @@ public class MCTSExecutor implements Callable<Boolean> {
|
||||||
public MCTSExecutor(Game sim, UUID playerId, int thinkTime) {
|
public MCTSExecutor(Game sim, UUID playerId, int thinkTime) {
|
||||||
this.playerId = playerId;
|
this.playerId = playerId;
|
||||||
this.thinkTime = thinkTime;
|
this.thinkTime = thinkTime;
|
||||||
root = new MCTSNode(sim);
|
root = new MCTSNode(playerId, sim);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -29,8 +29,11 @@ package mage.player.ai;
|
||||||
|
|
||||||
import java.util.ArrayDeque;
|
import java.util.ArrayDeque;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
import java.util.Iterator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Set;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
import mage.constants.PhaseStep;
|
import mage.constants.PhaseStep;
|
||||||
import mage.constants.Zone;
|
import mage.constants.Zone;
|
||||||
import mage.abilities.Ability;
|
import mage.abilities.Ability;
|
||||||
|
@ -51,6 +54,7 @@ import org.apache.log4j.Logger;
|
||||||
*/
|
*/
|
||||||
public class MCTSNode {
|
public class MCTSNode {
|
||||||
|
|
||||||
|
public static final boolean USE_ACTION_CACHE = false;
|
||||||
private static final double selectionCoefficient = Math.sqrt(2.0);
|
private static final double selectionCoefficient = Math.sqrt(2.0);
|
||||||
private static final double passRatioTolerance = 0.0;
|
private static final double passRatioTolerance = 0.0;
|
||||||
private static final transient Logger logger = Logger.getLogger(MCTSNode.class);
|
private static final transient Logger logger = Logger.getLogger(MCTSNode.class);
|
||||||
|
@ -63,14 +67,18 @@ public class MCTSNode {
|
||||||
private Game game;
|
private Game game;
|
||||||
private Combat combat;
|
private Combat combat;
|
||||||
private final String stateValue;
|
private final String stateValue;
|
||||||
|
private final String fullStateValue;
|
||||||
private UUID playerId;
|
private UUID playerId;
|
||||||
private boolean terminal = false;
|
private boolean terminal = false;
|
||||||
|
private UUID targetPlayer;
|
||||||
|
|
||||||
private static int nodeCount;
|
private static int nodeCount;
|
||||||
|
|
||||||
public MCTSNode(Game game) {
|
public MCTSNode(UUID targetPlayer, Game game) {
|
||||||
|
this.targetPlayer = targetPlayer;
|
||||||
this.game = game;
|
this.game = game;
|
||||||
this.stateValue = game.getState().getValue(false, game);
|
this.stateValue = game.getState().getValue(game, targetPlayer);
|
||||||
|
this.fullStateValue = game.getState().getValue(true, game);
|
||||||
this.terminal = game.gameOver(null);
|
this.terminal = game.gameOver(null);
|
||||||
setPlayer();
|
setPlayer();
|
||||||
nodeCount = 1;
|
nodeCount = 1;
|
||||||
|
@ -78,8 +86,10 @@ public class MCTSNode {
|
||||||
}
|
}
|
||||||
|
|
||||||
protected MCTSNode(MCTSNode parent, Game game, Ability action) {
|
protected MCTSNode(MCTSNode parent, Game game, Ability action) {
|
||||||
|
this.targetPlayer = parent.targetPlayer;
|
||||||
this.game = game;
|
this.game = game;
|
||||||
this.stateValue = game.getState().getValue(false, game);
|
this.stateValue = game.getState().getValue(game, targetPlayer);
|
||||||
|
this.fullStateValue = game.getState().getValue(true, game);
|
||||||
this.terminal = game.gameOver(null);
|
this.terminal = game.gameOver(null);
|
||||||
this.parent = parent;
|
this.parent = parent;
|
||||||
this.action = action;
|
this.action = action;
|
||||||
|
@ -89,9 +99,11 @@ public class MCTSNode {
|
||||||
}
|
}
|
||||||
|
|
||||||
protected MCTSNode(MCTSNode parent, Game game, Combat combat) {
|
protected MCTSNode(MCTSNode parent, Game game, Combat combat) {
|
||||||
|
this.targetPlayer = parent.targetPlayer;
|
||||||
this.game = game;
|
this.game = game;
|
||||||
this.combat = combat;
|
this.combat = combat;
|
||||||
this.stateValue = game.getState().getValue(false, game);
|
this.stateValue = game.getState().getValue(game, targetPlayer);
|
||||||
|
this.fullStateValue = game.getState().getValue(true, game);
|
||||||
this.terminal = game.gameOver(null);
|
this.terminal = game.gameOver(null);
|
||||||
this.parent = parent;
|
this.parent = parent;
|
||||||
setPlayer();
|
setPlayer();
|
||||||
|
@ -144,7 +156,11 @@ public class MCTSNode {
|
||||||
switch (player.getNextAction()) {
|
switch (player.getNextAction()) {
|
||||||
case PRIORITY:
|
case PRIORITY:
|
||||||
// logger.info("Priority for player:" + player.getName() + " turn: " + game.getTurnNum() + " phase: " + game.getPhase().getType() + " step: " + game.getStep().getType());
|
// logger.info("Priority for player:" + player.getName() + " turn: " + game.getTurnNum() + " phase: " + game.getPhase().getType() + " step: " + game.getStep().getType());
|
||||||
List<Ability> abilities = player.getPlayableOptions(game);
|
List<Ability> abilities;
|
||||||
|
if (!USE_ACTION_CACHE)
|
||||||
|
abilities = player.getPlayableOptions(game);
|
||||||
|
else
|
||||||
|
abilities = getPlayables(player, fullStateValue, game);
|
||||||
for (Ability ability: abilities) {
|
for (Ability ability: abilities) {
|
||||||
Game sim = game.copy();
|
Game sim = game.copy();
|
||||||
// logger.info("expand " + ability.toString());
|
// logger.info("expand " + ability.toString());
|
||||||
|
@ -156,7 +172,11 @@ public class MCTSNode {
|
||||||
break;
|
break;
|
||||||
case SELECT_ATTACKERS:
|
case SELECT_ATTACKERS:
|
||||||
// logger.info("Select attackers:" + player.getName());
|
// logger.info("Select attackers:" + player.getName());
|
||||||
List<List<UUID>> attacks = player.getAttacks(game);
|
List<List<UUID>> attacks;
|
||||||
|
if (!USE_ACTION_CACHE)
|
||||||
|
attacks = player.getAttacks(game);
|
||||||
|
else
|
||||||
|
attacks = getAttacks(player, fullStateValue, game);
|
||||||
UUID defenderId = game.getOpponents(player.getId()).iterator().next();
|
UUID defenderId = game.getOpponents(player.getId()).iterator().next();
|
||||||
for (List<UUID> attack: attacks) {
|
for (List<UUID> attack: attacks) {
|
||||||
Game sim = game.copy();
|
Game sim = game.copy();
|
||||||
|
@ -170,7 +190,11 @@ public class MCTSNode {
|
||||||
break;
|
break;
|
||||||
case SELECT_BLOCKERS:
|
case SELECT_BLOCKERS:
|
||||||
// logger.info("Select blockers:" + player.getName());
|
// logger.info("Select blockers:" + player.getName());
|
||||||
List<List<List<UUID>>> blocks = player.getBlocks(game);
|
List<List<List<UUID>>> blocks;
|
||||||
|
if (!USE_ACTION_CACHE)
|
||||||
|
blocks = player.getBlocks(game);
|
||||||
|
else
|
||||||
|
blocks = getBlocks(player, fullStateValue, game);
|
||||||
for (List<List<UUID>> block: blocks) {
|
for (List<List<UUID>> block: blocks) {
|
||||||
Game sim = game.copy();
|
Game sim = game.copy();
|
||||||
MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId());
|
MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId());
|
||||||
|
@ -454,4 +478,99 @@ public class MCTSNode {
|
||||||
return num;
|
return num;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static final ConcurrentHashMap<String, List<Ability>> playablesCache = new ConcurrentHashMap<>();
|
||||||
|
private static final ConcurrentHashMap<String, List<List<UUID>>> attacksCache = new ConcurrentHashMap<>();
|
||||||
|
private static final ConcurrentHashMap<String, List<List<List<UUID>>>> blocksCache = new ConcurrentHashMap<>();
|
||||||
|
|
||||||
|
private static long playablesHit = 0;
|
||||||
|
private static long playablesMiss = 0;
|
||||||
|
private static long attacksHit = 0;
|
||||||
|
private static long attacksMiss = 0;
|
||||||
|
private static long blocksHit = 0;
|
||||||
|
private static long blocksMiss = 0;
|
||||||
|
|
||||||
|
private static List<Ability> getPlayables(MCTSPlayer player, String state, Game game) {
|
||||||
|
if (playablesCache.containsKey(state)) {
|
||||||
|
playablesHit++;
|
||||||
|
return playablesCache.get(state);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
playablesMiss++;
|
||||||
|
List<Ability> abilities = player.getPlayableOptions(game);
|
||||||
|
playablesCache.put(state, abilities);
|
||||||
|
return abilities;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<List<UUID>> getAttacks(MCTSPlayer player, String state, Game game) {
|
||||||
|
if (attacksCache.containsKey(state)) {
|
||||||
|
attacksHit++;
|
||||||
|
return attacksCache.get(state);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
attacksMiss++;
|
||||||
|
List<List<UUID>> attacks = player.getAttacks(game);
|
||||||
|
attacksCache.put(state, attacks);
|
||||||
|
return attacks;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<List<List<UUID>>> getBlocks(MCTSPlayer player, String state, Game game) {
|
||||||
|
if (blocksCache.containsKey(state)) {
|
||||||
|
blocksHit++;
|
||||||
|
return blocksCache.get(state);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
blocksMiss++;
|
||||||
|
List<List<List<UUID>>> blocks = player.getBlocks(game);
|
||||||
|
blocksCache.put(state, blocks);
|
||||||
|
return blocks;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static int cleanupCache(int turnNum) {
|
||||||
|
Set<String> playablesKeys = playablesCache.keySet();
|
||||||
|
Iterator<String> playablesIterator = playablesKeys.iterator();
|
||||||
|
int count = 0;
|
||||||
|
while(playablesIterator.hasNext()) {
|
||||||
|
String next = playablesIterator.next();
|
||||||
|
int cacheTurn = Integer.valueOf(next.split(":", 2)[0].substring(1));
|
||||||
|
if (cacheTurn < turnNum) {
|
||||||
|
playablesIterator.remove();
|
||||||
|
count++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Set<String> attacksKeys = attacksCache.keySet();
|
||||||
|
Iterator<String> attacksIterator = attacksKeys.iterator();
|
||||||
|
while(attacksIterator.hasNext()) {
|
||||||
|
int cacheTurn = Integer.valueOf(attacksIterator.next().split(":", 2)[0].substring(1));
|
||||||
|
if (cacheTurn < turnNum) {
|
||||||
|
attacksIterator.remove();
|
||||||
|
count++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Set<String> blocksKeys = blocksCache.keySet();
|
||||||
|
Iterator<String> blocksIterator = blocksKeys.iterator();
|
||||||
|
while(blocksIterator.hasNext()) {
|
||||||
|
int cacheTurn = Integer.valueOf(blocksIterator.next().split(":", 2)[0].substring(1));
|
||||||
|
if (cacheTurn < turnNum) {
|
||||||
|
blocksIterator.remove();
|
||||||
|
count++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return count;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void logHitMiss() {
|
||||||
|
if (USE_ACTION_CACHE) {
|
||||||
|
StringBuilder sb = new StringBuilder();
|
||||||
|
sb.append("Playables Cache -- Hits: ").append(playablesHit).append(" Misses: ").append(playablesMiss).append("\n");
|
||||||
|
sb.append("Attacks Cache -- Hits: ").append(attacksHit).append(" Misses: ").append(attacksMiss).append("\n");
|
||||||
|
sb.append("Blocks Cache -- Hits: ").append(blocksHit).append(" Misses: ").append(blocksMiss).append("\n");
|
||||||
|
logger.info(sb.toString());
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -226,7 +226,66 @@ public class GameState implements Serializable, Copyable<GameState> {
|
||||||
|
|
||||||
for (Player player: players.values()) {
|
for (Player player: players.values()) {
|
||||||
sb.append("player").append(player.isPassed()).append(player.getLife()).append("hand");
|
sb.append("player").append(player.isPassed()).append(player.getLife()).append("hand");
|
||||||
if (useHidden && priorityPlayerId == player.getId()) {
|
if (useHidden) {
|
||||||
|
sb.append(player.getHand().getValue(game));
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
sb.append(player.getHand().size());
|
||||||
|
}
|
||||||
|
sb.append("library").append(player.getLibrary().size());
|
||||||
|
sb.append("graveyard");
|
||||||
|
sb.append(player.getGraveyard().getValue(game));
|
||||||
|
}
|
||||||
|
|
||||||
|
sb.append("permanents");
|
||||||
|
List<String> perms = new ArrayList<>();
|
||||||
|
for (Permanent permanent: battlefield.getAllPermanents()) {
|
||||||
|
perms.add(permanent.getValue());
|
||||||
|
}
|
||||||
|
Collections.sort(perms);
|
||||||
|
sb.append(perms);
|
||||||
|
|
||||||
|
sb.append("spells");
|
||||||
|
for (StackObject spell: stack) {
|
||||||
|
sb.append(spell.getControllerId()).append(spell.getName());
|
||||||
|
sb.append(spell.getStackAbility().toString());
|
||||||
|
for (Mode mode: spell.getStackAbility().getModes().values()) {
|
||||||
|
if (!mode.getTargets().isEmpty()) {
|
||||||
|
sb.append("targets");
|
||||||
|
for (Target target: mode.getTargets()) {
|
||||||
|
sb.append(target.getTargets());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!mode.getChoices().isEmpty()) {
|
||||||
|
sb.append("choices");
|
||||||
|
for (Choice choice: mode.getChoices()) {
|
||||||
|
sb.append(choice.getChoice());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (ExileZone zone: exile.getExileZones()) {
|
||||||
|
sb.append("exile").append(zone.getName()).append(zone.getValue(game));
|
||||||
|
}
|
||||||
|
|
||||||
|
sb.append("combat");
|
||||||
|
for (CombatGroup group: combat.getGroups()) {
|
||||||
|
sb.append(group.getDefenderId()).append(group.getAttackers()).append(group.getBlockers());
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb.toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getValue(Game game, UUID playerId) {
|
||||||
|
StringBuilder sb = threadLocalBuilder.get();
|
||||||
|
|
||||||
|
sb.append(turn.getValue(turnNum));
|
||||||
|
sb.append(activePlayerId).append(priorityPlayerId);
|
||||||
|
|
||||||
|
for (Player player: players.values()) {
|
||||||
|
sb.append("player").append(player.isPassed()).append(player.getLife()).append("hand");
|
||||||
|
if (playerId == player.getId()) {
|
||||||
sb.append(player.getHand().getValue(game));
|
sb.append(player.getHand().getValue(game));
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
Loading…
Reference in a new issue