mirror of
https://github.com/correl/mage.git
synced 2024-11-28 19:19:55 +00:00
more MCTS improvements
This commit is contained in:
parent
6403fff12b
commit
c5e216ddbf
4 changed files with 199 additions and 12 deletions
|
@ -92,6 +92,11 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player {
|
|||
if (!lastPhase.equals(game.getTurn().getValue(game.getTurnNum()))) {
|
||||
logList(game.getTurn().getValue(game.getTurnNum()) + name + " hand: ", new ArrayList(hand.getCards(game)));
|
||||
lastPhase = game.getTurn().getValue(game.getTurnNum());
|
||||
if (MCTSNode.USE_ACTION_CACHE) {
|
||||
int count = MCTSNode.cleanupCache(game.getTurnNum());
|
||||
if (count > 0)
|
||||
logger.info("Removed " + count + " cache entries");
|
||||
}
|
||||
}
|
||||
}
|
||||
game.getState().setPriorityPlayerId(playerId);
|
||||
|
@ -113,7 +118,7 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player {
|
|||
Game sim = createMCTSGame(game);
|
||||
MCTSPlayer player = (MCTSPlayer) sim.getPlayer(playerId);
|
||||
player.setNextAction(action);
|
||||
root = new MCTSNode(sim);
|
||||
root = new MCTSNode(playerId, sim);
|
||||
}
|
||||
applyMCTS(game, action);
|
||||
root = root.bestChild();
|
||||
|
@ -123,7 +128,7 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player {
|
|||
protected void getNextAction(Game game, NextAction nextAction) {
|
||||
if (root != null) {
|
||||
MCTSNode newRoot;
|
||||
newRoot = root.getMatchingState(game.getState().getValue(false, game));
|
||||
newRoot = root.getMatchingState(game.getState().getValue(game, playerId));
|
||||
if (newRoot != null) {
|
||||
newRoot.emancipate();
|
||||
}
|
||||
|
@ -213,6 +218,7 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player {
|
|||
sb.append(game.getPermanent(attackerId).getName()).append(",");
|
||||
}
|
||||
logger.info(sb.toString());
|
||||
MCTSNode.logHitMiss();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -233,6 +239,7 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player {
|
|||
}
|
||||
}
|
||||
logger.info(sb.toString());
|
||||
MCTSNode.logHitMiss();
|
||||
}
|
||||
|
||||
// @Override
|
||||
|
@ -273,6 +280,7 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player {
|
|||
protected long totalThinkTime = 0;
|
||||
protected long totalSimulations = 0;
|
||||
protected void applyMCTS(final Game game, final NextAction action) {
|
||||
|
||||
int thinkTime = calculateThinkTime(game, action);
|
||||
|
||||
if (thinkTime > 0) {
|
||||
|
@ -306,6 +314,7 @@ public class ComputerPlayerMCTS extends ComputerPlayer implements Player {
|
|||
totalSimulations += simCount;
|
||||
logger.info("Player: " + name + " Simulated " + simCount + " games in " + thinkTime + " seconds - nodes in tree: " + root.size());
|
||||
logger.info("Total: Simulated " + totalSimulations + " games in " + totalThinkTime + " seconds - Average: " + totalSimulations/totalThinkTime);
|
||||
MCTSNode.logHitMiss();
|
||||
}
|
||||
else {
|
||||
long startTime = System.nanoTime();
|
||||
|
|
|
@ -48,7 +48,7 @@ public class MCTSExecutor implements Callable<Boolean> {
|
|||
public MCTSExecutor(Game sim, UUID playerId, int thinkTime) {
|
||||
this.playerId = playerId;
|
||||
this.thinkTime = thinkTime;
|
||||
root = new MCTSNode(sim);
|
||||
root = new MCTSNode(playerId, sim);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -29,8 +29,11 @@ package mage.player.ai;
|
|||
|
||||
import java.util.ArrayDeque;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import mage.constants.PhaseStep;
|
||||
import mage.constants.Zone;
|
||||
import mage.abilities.Ability;
|
||||
|
@ -51,6 +54,7 @@ import org.apache.log4j.Logger;
|
|||
*/
|
||||
public class MCTSNode {
|
||||
|
||||
public static final boolean USE_ACTION_CACHE = false;
|
||||
private static final double selectionCoefficient = Math.sqrt(2.0);
|
||||
private static final double passRatioTolerance = 0.0;
|
||||
private static final transient Logger logger = Logger.getLogger(MCTSNode.class);
|
||||
|
@ -63,14 +67,18 @@ public class MCTSNode {
|
|||
private Game game;
|
||||
private Combat combat;
|
||||
private final String stateValue;
|
||||
private final String fullStateValue;
|
||||
private UUID playerId;
|
||||
private boolean terminal = false;
|
||||
private UUID targetPlayer;
|
||||
|
||||
private static int nodeCount;
|
||||
|
||||
public MCTSNode(Game game) {
|
||||
public MCTSNode(UUID targetPlayer, Game game) {
|
||||
this.targetPlayer = targetPlayer;
|
||||
this.game = game;
|
||||
this.stateValue = game.getState().getValue(false, game);
|
||||
this.stateValue = game.getState().getValue(game, targetPlayer);
|
||||
this.fullStateValue = game.getState().getValue(true, game);
|
||||
this.terminal = game.gameOver(null);
|
||||
setPlayer();
|
||||
nodeCount = 1;
|
||||
|
@ -78,8 +86,10 @@ public class MCTSNode {
|
|||
}
|
||||
|
||||
protected MCTSNode(MCTSNode parent, Game game, Ability action) {
|
||||
this.targetPlayer = parent.targetPlayer;
|
||||
this.game = game;
|
||||
this.stateValue = game.getState().getValue(false, game);
|
||||
this.stateValue = game.getState().getValue(game, targetPlayer);
|
||||
this.fullStateValue = game.getState().getValue(true, game);
|
||||
this.terminal = game.gameOver(null);
|
||||
this.parent = parent;
|
||||
this.action = action;
|
||||
|
@ -89,9 +99,11 @@ public class MCTSNode {
|
|||
}
|
||||
|
||||
protected MCTSNode(MCTSNode parent, Game game, Combat combat) {
|
||||
this.targetPlayer = parent.targetPlayer;
|
||||
this.game = game;
|
||||
this.combat = combat;
|
||||
this.stateValue = game.getState().getValue(false, game);
|
||||
this.stateValue = game.getState().getValue(game, targetPlayer);
|
||||
this.fullStateValue = game.getState().getValue(true, game);
|
||||
this.terminal = game.gameOver(null);
|
||||
this.parent = parent;
|
||||
setPlayer();
|
||||
|
@ -144,7 +156,11 @@ public class MCTSNode {
|
|||
switch (player.getNextAction()) {
|
||||
case PRIORITY:
|
||||
// logger.info("Priority for player:" + player.getName() + " turn: " + game.getTurnNum() + " phase: " + game.getPhase().getType() + " step: " + game.getStep().getType());
|
||||
List<Ability> abilities = player.getPlayableOptions(game);
|
||||
List<Ability> abilities;
|
||||
if (!USE_ACTION_CACHE)
|
||||
abilities = player.getPlayableOptions(game);
|
||||
else
|
||||
abilities = getPlayables(player, fullStateValue, game);
|
||||
for (Ability ability: abilities) {
|
||||
Game sim = game.copy();
|
||||
// logger.info("expand " + ability.toString());
|
||||
|
@ -156,7 +172,11 @@ public class MCTSNode {
|
|||
break;
|
||||
case SELECT_ATTACKERS:
|
||||
// logger.info("Select attackers:" + player.getName());
|
||||
List<List<UUID>> attacks = player.getAttacks(game);
|
||||
List<List<UUID>> attacks;
|
||||
if (!USE_ACTION_CACHE)
|
||||
attacks = player.getAttacks(game);
|
||||
else
|
||||
attacks = getAttacks(player, fullStateValue, game);
|
||||
UUID defenderId = game.getOpponents(player.getId()).iterator().next();
|
||||
for (List<UUID> attack: attacks) {
|
||||
Game sim = game.copy();
|
||||
|
@ -170,7 +190,11 @@ public class MCTSNode {
|
|||
break;
|
||||
case SELECT_BLOCKERS:
|
||||
// logger.info("Select blockers:" + player.getName());
|
||||
List<List<List<UUID>>> blocks = player.getBlocks(game);
|
||||
List<List<List<UUID>>> blocks;
|
||||
if (!USE_ACTION_CACHE)
|
||||
blocks = player.getBlocks(game);
|
||||
else
|
||||
blocks = getBlocks(player, fullStateValue, game);
|
||||
for (List<List<UUID>> block: blocks) {
|
||||
Game sim = game.copy();
|
||||
MCTSPlayer simPlayer = (MCTSPlayer) sim.getPlayer(player.getId());
|
||||
|
@ -454,4 +478,99 @@ public class MCTSNode {
|
|||
return num;
|
||||
}
|
||||
|
||||
private static final ConcurrentHashMap<String, List<Ability>> playablesCache = new ConcurrentHashMap<>();
|
||||
private static final ConcurrentHashMap<String, List<List<UUID>>> attacksCache = new ConcurrentHashMap<>();
|
||||
private static final ConcurrentHashMap<String, List<List<List<UUID>>>> blocksCache = new ConcurrentHashMap<>();
|
||||
|
||||
private static long playablesHit = 0;
|
||||
private static long playablesMiss = 0;
|
||||
private static long attacksHit = 0;
|
||||
private static long attacksMiss = 0;
|
||||
private static long blocksHit = 0;
|
||||
private static long blocksMiss = 0;
|
||||
|
||||
private static List<Ability> getPlayables(MCTSPlayer player, String state, Game game) {
|
||||
if (playablesCache.containsKey(state)) {
|
||||
playablesHit++;
|
||||
return playablesCache.get(state);
|
||||
}
|
||||
else {
|
||||
playablesMiss++;
|
||||
List<Ability> abilities = player.getPlayableOptions(game);
|
||||
playablesCache.put(state, abilities);
|
||||
return abilities;
|
||||
}
|
||||
}
|
||||
|
||||
private static List<List<UUID>> getAttacks(MCTSPlayer player, String state, Game game) {
|
||||
if (attacksCache.containsKey(state)) {
|
||||
attacksHit++;
|
||||
return attacksCache.get(state);
|
||||
}
|
||||
else {
|
||||
attacksMiss++;
|
||||
List<List<UUID>> attacks = player.getAttacks(game);
|
||||
attacksCache.put(state, attacks);
|
||||
return attacks;
|
||||
}
|
||||
}
|
||||
|
||||
private static List<List<List<UUID>>> getBlocks(MCTSPlayer player, String state, Game game) {
|
||||
if (blocksCache.containsKey(state)) {
|
||||
blocksHit++;
|
||||
return blocksCache.get(state);
|
||||
}
|
||||
else {
|
||||
blocksMiss++;
|
||||
List<List<List<UUID>>> blocks = player.getBlocks(game);
|
||||
blocksCache.put(state, blocks);
|
||||
return blocks;
|
||||
}
|
||||
}
|
||||
|
||||
public static int cleanupCache(int turnNum) {
|
||||
Set<String> playablesKeys = playablesCache.keySet();
|
||||
Iterator<String> playablesIterator = playablesKeys.iterator();
|
||||
int count = 0;
|
||||
while(playablesIterator.hasNext()) {
|
||||
String next = playablesIterator.next();
|
||||
int cacheTurn = Integer.valueOf(next.split(":", 2)[0].substring(1));
|
||||
if (cacheTurn < turnNum) {
|
||||
playablesIterator.remove();
|
||||
count++;
|
||||
}
|
||||
}
|
||||
|
||||
Set<String> attacksKeys = attacksCache.keySet();
|
||||
Iterator<String> attacksIterator = attacksKeys.iterator();
|
||||
while(attacksIterator.hasNext()) {
|
||||
int cacheTurn = Integer.valueOf(attacksIterator.next().split(":", 2)[0].substring(1));
|
||||
if (cacheTurn < turnNum) {
|
||||
attacksIterator.remove();
|
||||
count++;
|
||||
}
|
||||
}
|
||||
|
||||
Set<String> blocksKeys = blocksCache.keySet();
|
||||
Iterator<String> blocksIterator = blocksKeys.iterator();
|
||||
while(blocksIterator.hasNext()) {
|
||||
int cacheTurn = Integer.valueOf(blocksIterator.next().split(":", 2)[0].substring(1));
|
||||
if (cacheTurn < turnNum) {
|
||||
blocksIterator.remove();
|
||||
count++;
|
||||
}
|
||||
}
|
||||
|
||||
return count;
|
||||
}
|
||||
|
||||
public static void logHitMiss() {
|
||||
if (USE_ACTION_CACHE) {
|
||||
StringBuilder sb = new StringBuilder();
|
||||
sb.append("Playables Cache -- Hits: ").append(playablesHit).append(" Misses: ").append(playablesMiss).append("\n");
|
||||
sb.append("Attacks Cache -- Hits: ").append(attacksHit).append(" Misses: ").append(attacksMiss).append("\n");
|
||||
sb.append("Blocks Cache -- Hits: ").append(blocksHit).append(" Misses: ").append(blocksMiss).append("\n");
|
||||
logger.info(sb.toString());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -226,7 +226,66 @@ public class GameState implements Serializable, Copyable<GameState> {
|
|||
|
||||
for (Player player: players.values()) {
|
||||
sb.append("player").append(player.isPassed()).append(player.getLife()).append("hand");
|
||||
if (useHidden && priorityPlayerId == player.getId()) {
|
||||
if (useHidden) {
|
||||
sb.append(player.getHand().getValue(game));
|
||||
}
|
||||
else {
|
||||
sb.append(player.getHand().size());
|
||||
}
|
||||
sb.append("library").append(player.getLibrary().size());
|
||||
sb.append("graveyard");
|
||||
sb.append(player.getGraveyard().getValue(game));
|
||||
}
|
||||
|
||||
sb.append("permanents");
|
||||
List<String> perms = new ArrayList<>();
|
||||
for (Permanent permanent: battlefield.getAllPermanents()) {
|
||||
perms.add(permanent.getValue());
|
||||
}
|
||||
Collections.sort(perms);
|
||||
sb.append(perms);
|
||||
|
||||
sb.append("spells");
|
||||
for (StackObject spell: stack) {
|
||||
sb.append(spell.getControllerId()).append(spell.getName());
|
||||
sb.append(spell.getStackAbility().toString());
|
||||
for (Mode mode: spell.getStackAbility().getModes().values()) {
|
||||
if (!mode.getTargets().isEmpty()) {
|
||||
sb.append("targets");
|
||||
for (Target target: mode.getTargets()) {
|
||||
sb.append(target.getTargets());
|
||||
}
|
||||
}
|
||||
if (!mode.getChoices().isEmpty()) {
|
||||
sb.append("choices");
|
||||
for (Choice choice: mode.getChoices()) {
|
||||
sb.append(choice.getChoice());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (ExileZone zone: exile.getExileZones()) {
|
||||
sb.append("exile").append(zone.getName()).append(zone.getValue(game));
|
||||
}
|
||||
|
||||
sb.append("combat");
|
||||
for (CombatGroup group: combat.getGroups()) {
|
||||
sb.append(group.getDefenderId()).append(group.getAttackers()).append(group.getBlockers());
|
||||
}
|
||||
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
public String getValue(Game game, UUID playerId) {
|
||||
StringBuilder sb = threadLocalBuilder.get();
|
||||
|
||||
sb.append(turn.getValue(turnNum));
|
||||
sb.append(activePlayerId).append(priorityPlayerId);
|
||||
|
||||
for (Player player: players.values()) {
|
||||
sb.append("player").append(player.isPassed()).append(player.getLife()).append("hand");
|
||||
if (playerId == player.getId()) {
|
||||
sb.append(player.getHand().getValue(game));
|
||||
}
|
||||
else {
|
||||
|
|
Loading…
Reference in a new issue