diff --git a/main.py b/main.py index 42b8ca2..97842e1 100644 --- a/main.py +++ b/main.py @@ -10,6 +10,7 @@ from umqtt.simple import MQTTClient import ssd1306 import mcp4 +from statetree import StateTree VOLUME_MAX = 128 @@ -17,12 +18,15 @@ MQTT_KEEPALIVE = 60 MQTT_UPDATE_INTERVAL = 60 -state = { - "volume": { - "left": 0, - "right": 0, +state = StateTree( + { + "network": "OFF", + "volume": { + "left": 0, + "right": 0, + }, } -} +) last_update = 0 i2c = SoftI2C(sda=Pin(2), scl=Pin(16)) @@ -33,7 +37,6 @@ buf = bytearray((oled_height // 8) * oled_width) fbuf = framebuf.FrameBuffer1(buf, oled_width, oled_height) sta_if = network.WLAN(network.STA_IF) -network_status = "OFF" spi = SPI(1) cs = Pin(15, mode=Pin.OUT, value=1) @@ -62,18 +65,10 @@ def on_message(topic, msg): def loop(): - global mqtt, network_status, state, last_update + global mqtt, state, last_update - state_changed = False - PW0 = pot.read(0) - PW1 = pot.read(1) - - if PW0 != state["volume"]["left"]: - state["volume"]["left"] = PW0 - state_changed = True - if PW1 != state["volume"]["right"]: - state["volume"]["right"] = PW1 - state_changed = True + state["volume"]["left"] = pot.read(0) + state["volume"]["right"] = pot.read(1) if not sta_if.active(): print("Connecting to WiFi") @@ -81,13 +76,13 @@ def loop(): sta_if.connect(settings["wifi"]["ssid"], settings["wifi"]["password"]) if sta_if.active() and not sta_if.isconnected(): - network_status = "ACT" + state["network"] = "ACT" if sta_if.isconnected(): - if network_status != "OK": + if state["network"] != "OK": ip, _, _, _ = sta_if.ifconfig() print(f"WIFI Connected to {sta_if.config('ssid')}") print(f"IP Address: {ip}") - network_status = "OK" + state["network"] = "OK" if not mqtt: print("Starting MQTT client") mqtt = MQTTClient(mqtt_client_id, mqtt_broker, keepalive=MQTT_KEEPALIVE) @@ -166,20 +161,21 @@ def loop(): retain=True, ) - if state_changed or utime.time() - last_update >= MQTT_UPDATE_INTERVAL: - topic = f"{mqtt_prefix}/state" - payload = json.dumps(state) + if state.changed or utime.time() - last_update >= MQTT_UPDATE_INTERVAL: + topic = f"{mqtt_prefix}/state".encode() + payload = json.dumps(state.dictionary).encode() print(f"MQTT -> [{topic}] {payload}") - mqtt.publish(f"{mqtt_prefix}/status".encode(), b"online", retain=True) - mqtt.publish(topic.encode(), payload.encode(), retain=True) - + mqtt.publish(f"{mqtt_prefix}/status", b"online", retain=True) + mqtt.publish(topic, payload, retain=True) last_update = utime.time() mqtt.check_msg() - oled.fill(0) - oled.text(f"PW0: {PW0}", 0, 0) - oled.text(f"PW1: {PW1}", 0, 10) - oled.text(f"NET: {network_status}", 65, 0) - oled.show() + if state.changed: + oled.fill(0) + oled.text(f"LFT: {state['volume']['left']}", 0, 0) + oled.text(f"RGT: {state['volume']['right']}", 0, 10) + oled.text(f"NET: {state['network']}", 65, 0) + oled.show() + state.clean() while True: diff --git a/statetree.py b/statetree.py new file mode 100644 index 0000000..8499ef2 --- /dev/null +++ b/statetree.py @@ -0,0 +1,77 @@ +class StateTree: + """A dictionary-like object that tracks when values have been changed.""" + + def __init__(self, dictionary: dict = None, parent: "StateTree" = None) -> None: + """Create a new state tree. + + If a dictionary is supplied, its values will be initialized with it and + the tree will be marked as clean. + + If a parent is supplied, the parent will be marked as dirty when this + tree is modified. + + """ + self._dictionary = dictionary if dictionary else dict() + self._parent = parent + self._changed = False + + def dirty(self): + """Mark the tree as dirty. + + This is done automatically whenever an item is updated. + + """ + self._changed = True + if self._parent: + self._parent.dirty() + + def clean(self): + """Mark the tree as clean. + + Use this method to reset the changed status of the tree once after + you've reacted to it being updated. + + """ + self._changed = False + + @property + def changed(self): + """Returns whether the tree has been modified since the last time it was + marked as clean.""" + return self._changed + + @property + def dictionary(self): + """Returns the underlying dictionary.""" + return self._dictionary + + def __getitem__(self, *args, **kwargs): + """Get the value stored in a key in the tree. + + If the value is a dictionary, returns a StateTree object instead that + will notify the parent if a change is made. + + """ + o = self._dictionary.__getitem__(*args, **kwargs) + if isinstance(o, dict): + return StateTree(o, parent=self) + else: + return o + + def __setitem__(self, key, value): + """Update the value of a key in the tree. + + Marks the tree as changed if the key is new or the new value differs + from the current value. + + """ + if key not in self._dictionary or value != self._dictionary[key]: + self.dirty() + self._dictionary[key] = value + + def __repr__(self): + return "".format( + "^" if self._parent else "", + "*" if self._changed else "", + repr(self._dictionary), + ) diff --git a/tests/__init__.py b/tests/__init__.py index 5f6deed..5dee550 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1 +1,2 @@ +from .test_statetree import * from .test_mcp4 import * diff --git a/tests/test_statetree.py b/tests/test_statetree.py new file mode 100644 index 0000000..a98cdfe --- /dev/null +++ b/tests/test_statetree.py @@ -0,0 +1,38 @@ +import unittest + +from statetree import StateTree + + +class StateTreeTests(unittest.TestCase): + def test_new_empty_tree_is_clean(self): + tree = StateTree() + self.assertEqual(dict(), tree.dictionary) + self.assertFalse(tree.changed) + + def test_setting_item_dirties_tree(self): + tree = StateTree() + tree["foo"] = "bar" + self.assertEqual({"foo": "bar"}, tree.dictionary) + self.assertTrue(tree.changed) + + def test_setting_an_equivalent_value_does_not_dirty_tree(self): + tree = StateTree({"foo": "bar"}) + tree["foo"] = "bar" + self.assertFalse(tree.changed) + + def test_setting_nested_item_dirties_parent(self): + tree = StateTree({"foo": {"bar": "baz"}}) + tree["foo"]["bar"] = "changed" + self.assertEqual({"foo": {"bar": "changed"}}, tree.dictionary) + self.assertTrue(tree.changed) + + def test_dirtying_sets_changed_status(self): + tree = StateTree() + tree.dirty() + self.assertTrue(tree.changed) + + def test_cleaning_removes_changed_status(self): + tree = StateTree() + tree.dirty() + tree.clean() + self.assertFalse(tree.changed)