Emergent generative agents
修訂 | d4e0f10afb105d6ea0b70c43cdb4d3127c60eca0 (tree) |
---|---|
時間 | 2023-05-05 05:11:41 |
作者 | Corbin <cds@corb...> |
Committer | Corbin |
Switch to Twisted.
@@ -48,7 +48,8 @@ | ||
48 | 48 | doCheck = false; |
49 | 49 | }; |
50 | 50 | py = pkgs.python310.withPackages (ps: with ps; [ |
51 | - faiss irc llama-cpp-python sentence-transformers tokenizers transformers torch | |
51 | + faiss llama-cpp-python sentence-transformers tokenizers transformers torch | |
52 | + twisted | |
52 | 53 | ]); |
53 | 54 | rwkv = pkgs.stdenv.mkDerivation { |
54 | 55 | name = "rwkv.cpp"; |
@@ -1,6 +1,5 @@ | ||
1 | 1 | #!/usr/bin/env python3 |
2 | 2 | |
3 | -import asyncio | |
4 | 3 | from collections import deque |
5 | 4 | from datetime import datetime |
6 | 5 | import json |
@@ -8,11 +7,14 @@ import os.path | ||
8 | 7 | import re |
9 | 8 | import random |
10 | 9 | import sys |
11 | -from threading import RLock | |
10 | +from threading import Lock | |
12 | 11 | |
13 | -from irc.client_aio import AioReactor | |
14 | -from irc.bot import SingleServerIRCBot | |
15 | -from irc.strings import lower | |
12 | +from twisted.internet import reactor | |
13 | +from twisted.internet.defer import succeed | |
14 | +from twisted.internet.protocol import ClientFactory | |
15 | +from twisted.internet.task import LoopingCall | |
16 | +from twisted.internet.threads import deferToThread | |
17 | +from twisted.words.protocols.irc import IRCClient | |
16 | 18 | |
17 | 19 | from common import irc_line, Timer, SentenceIndex, breakAt |
18 | 20 | # from gens.camelid import CamelidGen |
@@ -28,203 +30,204 @@ def load_character(path): | ||
28 | 30 | with open(os.path.join(path, "character.json"), "r") as handle: |
29 | 31 | return json.load(handle) |
30 | 32 | |
31 | -logpath = sys.argv[1] | |
32 | -character = load_character(logpath) | |
33 | -startingChannels = character.pop("startingChannels") | |
34 | -title = character["title"] | |
35 | - | |
36 | 33 | MAX_NEW_TOKENS = 128 |
37 | 34 | print("~ Initializing mawrkov adapter…") |
38 | 35 | gen = MawrkovGen(MAX_NEW_TOKENS) |
36 | +# Need to protect per-gen data structures in C. | |
37 | +genLock = Lock() | |
39 | 38 | GiB = 1024 ** 3 |
40 | 39 | print("Initialized", gen.model_name, "(" + gen.model_arch + ")", |
41 | 40 | "Memory footprint:", round(gen.footprint() / GiB, 2), "GiB") |
42 | 41 | embedder = SentenceEmbed() |
43 | -thoughtPath = os.path.join(logpath, "thoughts.txt") | |
44 | -thoughtIndex = SentenceIndex.fromPath(thoughtPath, embedder) | |
45 | -print("~ Thought index:", thoughtIndex.size(), "thoughts") | |
42 | + | |
43 | +prologues = { | |
44 | + "clock": "Then I checked the time:", | |
45 | + "lbi": "Then I tried to interpret what just happened:", | |
46 | + "thoughts": "Then I thought to myself:", | |
47 | + "irc": "Then I chatted on IRC:", | |
48 | +} | |
46 | 49 | |
47 | 50 | class Mind: |
48 | - currentAgent = None | |
51 | + currentTag = None | |
49 | 52 | logits = state = None |
50 | 53 | |
51 | - def __init__(self): self.lock = RLock() | |
54 | + def switchTag(self, tag): | |
55 | + if tag == self.currentTag: return succeed(None) | |
56 | + else: | |
57 | + self.currentTag = tag | |
58 | + # Double newlines are added here. | |
59 | + return deferToThread(self.write, "\n" + prologues[tag]) | |
52 | 60 | |
53 | 61 | def overhear(self, tag, s): |
54 | - # XXX tag -> highlight agent first! | |
55 | - self.write(s) | |
62 | + d = self.switchTag(tag) | |
63 | + d.addCallback(lambda _: deferToThread(self.write, s)) | |
64 | + return d | |
56 | 65 | |
57 | 66 | def write(self, s): |
58 | - with self.lock: | |
67 | + with genLock: | |
59 | 68 | print("~ Write:", s) |
60 | 69 | # Newlines are added here. |
61 | 70 | self.logits, self.state = gen.feedForward(gen.tokenize(s + "\n"), |
62 | 71 | self.logits, self.state) |
63 | 72 | |
64 | 73 | def complete(self, s): |
65 | - with self.lock: | |
74 | + with genLock: | |
66 | 75 | print("~ Completion with prefix:", s) |
67 | 76 | completion, self.logits, self.state = gen.complete(s, self.logits, self.state) |
68 | 77 | print("~ Completion:", completion) |
69 | - line = s + completion | |
70 | - for listener in self.listeners: listener.listenTo(line) | |
71 | 78 | return completion |
72 | 79 | |
73 | - async def infer(self, tag, prefix): | |
74 | - # XXX | |
75 | - agent = ... | |
76 | - # XXX thread all writes and inference | |
77 | - self.write("") | |
78 | - self.write(agent.prologue()) | |
79 | - # XXX for line in agent.activate(): self.write(line) | |
80 | - # XXX get all of the messages matching the tag, feed them | |
81 | - for message in self.q[tag]: | |
82 | - self.write(message) | |
80 | + def infer(self, tag, prefix): | |
83 | 81 | print("~ agent prefix length (tokens):", gen.countTokens(prefix)) |
84 | - return self.complete(prefix) | |
82 | + d = self.switchTag(tag) | |
83 | + d.addCallback(lambda _: deferToThread(self.complete, prefix)) | |
84 | + return d | |
85 | 85 | |
86 | 86 | |
87 | 87 | class Agent: |
88 | 88 | listeners = () |
89 | 89 | def broadcast(self, s): |
90 | - for listener in self.listeners: | |
91 | - loop.call_soon(listener.overhear, self.tag, s) | |
90 | + for listener in self.listeners: listener.overhear(self.tag, s) | |
92 | 91 | |
93 | 92 | class Clock(Agent): |
94 | 93 | tag = "clock" |
95 | - def prologue(self): return ["Then I checked the time:"] | |
96 | - async def start(self): | |
97 | - while True: | |
98 | - self.broadcast(f"The time is currently {datetime.now():%H:%M:%S, %B %d %Y}.") | |
99 | - asyncio.sleep(60.0) | |
94 | + def go(self): | |
95 | + self.broadcast(f"The time is currently {datetime.now():%H:%M:%S, %B %d %Y}.") | |
100 | 96 | |
101 | 97 | class LeftBrainInterpreter(Agent): |
102 | 98 | tag = "lbi" |
103 | 99 | events = 0 |
104 | - def prologue(self): | |
105 | - return ["Then I tried to interpret what just happened:"] | |
106 | - def overhear(self, tag, s): self.events += 1 | |
107 | - | |
108 | -IRC_LINE_HEAD = re.compile(r"\d{1,2}:\d{1,2}:\d{1,2} <") | |
109 | -def breakIRCLine(line): | |
110 | - return IRC_LINE_HEAD.split(breakAt(line.strip(), "\n"), maxsplit=1)[0] | |
111 | - | |
112 | -class IRCAgent(Agent, SingleServerIRCBot): | |
113 | - reactor_class = AioReactor | |
114 | - | |
115 | - tag = "irc", None | |
116 | - def __init__(self, host, title, startingChannels): | |
117 | - super(IRCAgent, self).__init__([(host, 6667)], title_to_nick(title), title) | |
118 | - self.startingChannels = startingChannels | |
119 | - | |
120 | - def prologue(self): return ["There is activity in IRC."] | |
121 | - | |
122 | - def highlightChannel(self, channel): | |
123 | - if self.highlightedChannel != channel: | |
124 | - self.highlightedChannel = channel | |
125 | - self.log(f"In channel {channel}:") | |
126 | - | |
127 | - def prefix(self): | |
128 | - return f"{datetime.now():%H:%M:%S} <{self.connection.get_nickname()}>" | |
129 | - | |
130 | - def log(self, s): | |
131 | - # XXX channel? | |
132 | - self.broadcast(s) | |
133 | - | |
134 | - def on_join(self, c, e): | |
135 | - who = e.source.nick | |
136 | - channel = e.target | |
137 | - if who == c.get_nickname(): c.topic(channel) | |
138 | - self.log(f"{who} joins {channel}") | |
139 | - | |
140 | - def on_part(self, c, e): | |
141 | - who = e.source.nick | |
142 | - channel = e.target | |
143 | - self.log(f"{who} leaves {channel}") | |
144 | - | |
145 | - def on_currenttopic(self, c, e): | |
146 | - channel = e.arguments[0] | |
147 | - topic = e.arguments[1] | |
148 | - self.channels[channel].topic = topic | |
149 | - self.log(f"Topic for {channel} is now: {topic}") | |
150 | - | |
151 | - def on_welcome(self, c, e): | |
152 | - for channel in self.startingChannels: c.join(channel) | |
153 | - | |
154 | - def on_pubmsg(self, c, e): | |
155 | - line = e.arguments[0] | |
156 | - channel = e.target | |
157 | - self.highlightChannel(channel) | |
158 | - self.log(irc_line(datetime.now(), e.source.nick, line)) | |
159 | - # Vaguely inspired by | |
160 | - # https://github.com/jaraco/irc/blob/main/scripts/testbot.py | |
161 | - nick = lower(self.connection.get_nickname()) | |
162 | - lowered = lower(line) | |
163 | - if nick in lowered: | |
164 | - print("~ Will respond on IRC") | |
165 | - # XXX channel? | |
166 | - line = breakIRCLine(self.mind.infer()).strip() | |
167 | - self.connection.privmsg(channel, line) | |
168 | - | |
100 | + def __init__(self, mind): self.mind = mind | |
101 | + def overhear(self, tag, s): | |
102 | + self.events += 1 | |
103 | + if self.events >= 10: | |
104 | + self.events = 0 | |
105 | + return self.mind.infer(self.tag, "") | |
169 | 106 | |
170 | 107 | class ChainOfThoughts(Agent): |
171 | - def __init__(self, index, seed): | |
172 | - super(ChainOfThoughts, self).__init__() | |
108 | + tag = "thoughts" | |
109 | + def __init__(self, mind, index, seed): | |
110 | + self.mind = mind | |
173 | 111 | self.index = index |
174 | 112 | self.recentThoughts = deque([seed], maxlen=5) |
175 | - self.q = [] | |
176 | 113 | |
177 | - async def start(self): | |
178 | - while True: | |
179 | - cb = self.reflect if random.choice([0, 1]) else self.cogitate | |
180 | - await cb() | |
181 | - asyncio.sleep(30.0) | |
182 | - | |
183 | - def prologue(self): return ["Thinking to myself:"] | |
184 | - def prefix(self): return "" | |
114 | + def go(self): | |
115 | + cb = self.reflect if random.choice([0, 1]) else self.cogitate | |
116 | + return cb() | |
185 | 117 | |
186 | 118 | def addRelatedThoughts(self, s): |
187 | - thoughts = thoughtIndex.search(s, 2) | |
119 | + thoughts = self.index.search(s, 2) | |
188 | 120 | for thought in thoughts: |
189 | 121 | if thought not in self.recentThoughts: |
190 | 122 | print("~ New relevant thought:", thought) |
191 | 123 | self.recentThoughts.append(thought) |
192 | - self.q.append(thought) | |
124 | + self.broadcast(thought) | |
193 | 125 | |
194 | - listenTo = addRelatedThoughts | |
126 | + def cogitate(self): self.addRelatedThoughts(self.recentThoughts[-1]) | |
127 | + def reflect(self): | |
128 | + d = self.mind.infer(self.tag, "") | |
195 | 129 | |
196 | - async def cogitate(self): | |
197 | - self.addRelatedThoughts(self.recentThoughts[-1]) | |
198 | - if self.q: | |
199 | - for line in self.q: self.broadcast(line) | |
200 | - self.q = [] | |
130 | + @d.addCallback | |
131 | + def cb(s): | |
132 | + if not s.strip(): | |
133 | + self.broadcast(random.choice([ | |
134 | + "Head empty; no thoughts.", | |
135 | + "So bored.", | |
136 | + "Zoned out.", | |
137 | + ])) | |
201 | 138 | |
202 | - async def reflect(self): | |
203 | - # XXX no prefix, just complete | |
204 | - self.mind.highlightAgent(self) | |
205 | - thought = await self.mind.complete(self.tag, "") | |
206 | - print("~ Reflection:", thought) | |
207 | 139 | |
208 | -mind = Mind() | |
209 | -firstStatement = f"I am {title}." | |
210 | -with Timer("initial warmup"): | |
211 | - mind.logits, mind.state = gen.feedForward(gen.tokenize(firstStatement), None, None) | |
140 | +IRC_LINE_HEAD = re.compile(r"\d{1,2}:\d{1,2}:\d{1,2} <") | |
141 | +def breakIRCLine(line): | |
142 | + return IRC_LINE_HEAD.split(breakAt(line.strip(), "\n"), maxsplit=1)[0] | |
212 | 143 | |
213 | -agent = IRCAgent("june.local", title, startingChannels) | |
214 | -loop = agent.reactor.loop | |
144 | +class IRCAgent(Agent, IRCClient): | |
145 | + tag = "irc" | |
146 | + def __init__(self, mind, title, startingChannels): | |
147 | + super(IRCAgent, self).__init__() | |
148 | + self.mind = mind | |
149 | + self.nickname = title_to_nick(title) | |
150 | + self.startingChannels = startingChannels | |
215 | 151 | |
216 | -clock = Clock() | |
217 | -lbi = LeftBrainInterpreter() | |
218 | -thoughts = ChainOfThoughts(thoughtIndex, firstStatement) | |
219 | -clock.listeners = lbi, mind | |
220 | -agent.listeners = lbi, mind | |
221 | -thoughts.listeners = lbi, mind | |
152 | + def prefix(self, channel): | |
153 | + return f"{datetime.now():%H:%M:%S} {channel} <{self.nickname}>" | |
154 | + | |
155 | + # def on_join(self, c, e): | |
156 | + # who = e.source.nick | |
157 | + # channel = e.target | |
158 | + # if who == c.get_nickname(): c.topic(channel) | |
159 | + # self.broadcast(f"{who} joins {channel}") | |
160 | + | |
161 | + # def on_part(self, c, e): | |
162 | + # who = e.source.nick | |
163 | + # channel = e.target | |
164 | + # self.broadcast(f"{who} leaves {channel}") | |
165 | + | |
166 | + # def on_currenttopic(self, c, e): | |
167 | + # channel = e.arguments[0] | |
168 | + # topic = e.arguments[1] | |
169 | + # self.channels[channel].topic = topic | |
170 | + # self.broadcast(f"Topic for {channel} is now: {topic}") | |
171 | + | |
172 | + def signedOn(self): | |
173 | + for channel in self.startingChannels: self.join(channel) | |
174 | + | |
175 | + def privmsg(self, user, channel, line): | |
176 | + user = user.split("!", 1)[0] | |
177 | + self.broadcast(irc_line(datetime.now(), channel, user, line)) | |
178 | + if self.nickname in line: | |
179 | + print("~ Ping on IRC:", self.nickname) | |
180 | + d = self.mind.infer("irc", self.prefix(channel)) | |
181 | + | |
182 | + @d.addCallback | |
183 | + def cb(s): | |
184 | + line = breakIRCLine(s).strip() | |
185 | + self.msg(channel, line) | |
186 | + | |
187 | +class IRCFactory(ClientFactory): | |
188 | + protocol = IRCAgent | |
189 | + def __init__(self, mind, listeners, title, startingChannels): | |
190 | + super(IRCFactory, self).__init__() | |
191 | + self.mind = mind | |
192 | + self.listeners = listeners | |
193 | + self.title = title | |
194 | + self.startingChannels = startingChannels | |
195 | + def buildProtocol(self, addr): | |
196 | + protocol = self.protocol(self.mind, self.title, self.startingChannels) | |
197 | + protocol.factory = self | |
198 | + protocol.listeners = self.listeners | |
199 | + return protocol | |
222 | 200 | |
223 | 201 | def go(): |
224 | 202 | print("~ Starting tasks…") |
225 | - asyncio.create_task(clock.start()) | |
226 | - asyncio.create_task(thoughts.start()) | |
227 | - agent.start() | |
203 | + clock = Clock() | |
204 | + LoopingCall(clock.go).start(120.0, now=False) | |
205 | + | |
206 | + for logpath in sys.argv[1:]: | |
207 | + character = load_character(logpath) | |
208 | + title = character["title"] | |
209 | + firstStatement = f"I am {title}." | |
210 | + thoughtPath = os.path.join(logpath, "thoughts.txt") | |
211 | + thoughtIndex = SentenceIndex.fromPath(thoughtPath, embedder) | |
212 | + | |
213 | + mind = Mind() | |
214 | + with Timer("initial warmup"): | |
215 | + mind.logits, mind.state = gen.feedForward(gen.tokenize(firstStatement), None, None) | |
216 | + | |
217 | + lbi = LeftBrainInterpreter(mind) | |
218 | + clock.listeners += lbi, mind | |
219 | + | |
220 | + thoughts = ChainOfThoughts(mind, thoughtIndex, firstStatement) | |
221 | + thoughts.listeners = lbi, mind | |
222 | + LoopingCall(thoughts.go).start(60.0, now=False) | |
223 | + | |
224 | + print("~ Thought index:", thoughtIndex.size(), "thoughts") | |
225 | + factory = IRCFactory(mind, (lbi, mind), | |
226 | + title, | |
227 | + character["startingChannels"]) | |
228 | + print("~ Connecting factory for:", title) | |
229 | + reactor.connectTCP("june.local", 6667, factory) | |
230 | + print("~ Starting event loop…") | |
231 | + reactor.run() | |
228 | 232 | |
229 | -print("~ Starting event loop…") | |
230 | 233 | go() |
@@ -1,4 +1,3 @@ | ||
1 | -from bisect import bisect | |
2 | 1 | from itertools import islice |
3 | 2 | import random |
4 | 3 | from time import perf_counter |
@@ -13,36 +12,8 @@ class Timer: | ||
13 | 12 | def __exit__(self, *args): |
14 | 13 | print("Timer:", self.l, "%0.02f" % (perf_counter() - self.t), "seconds") |
15 | 14 | |
16 | -def irc_line(t, speaker, entry): return f"{t:%H:%M:%S} <{speaker}> {entry}" | |
17 | - | |
18 | -class Log: | |
19 | - "Basic scrollback and context management for conversations." | |
20 | - cutoff = 0 | |
21 | - def __init__(self, l): | |
22 | - self.l = l | |
23 | - self.stamp = len(self.l) | |
24 | - | |
25 | - def raw(self, line): | |
26 | - assert "\n" not in line | |
27 | - self.l.append(line) | |
28 | - print(self.stamp, ">", line) | |
29 | - self.stamp += 1 | |
30 | - | |
31 | - def push(self, speaker, entry): self.raw(speaker + ": " + entry) | |
32 | - def irc(self, t, speaker, entry): self.raw(irc_line(t, speaker, entry)) | |
33 | - | |
34 | - def finishPrompt(self, s, prefix): | |
35 | - return self.finishPromptAtCutoff(self.cutoff, s, prefix) | |
36 | - | |
37 | - def finishPromptAtCutoff(self, cutoff, s, prefix): | |
38 | - return s + "\n".join(self.l[cutoff:]) + "\n" + prefix | |
39 | - | |
40 | - def undo(self): self.l.pop() | |
41 | - | |
42 | - def bumpCutoff(self, max_context_length, prompt_length, prompt, prefix): | |
43 | - def keyfunc(i): | |
44 | - return -prompt_length(self.finishPromptAtCutoff(i, prompt, prefix)) | |
45 | - self.cutoff = bisect(range(len(self.l)), -max_context_length, key=keyfunc) | |
15 | +def irc_line(t, channel, speaker, entry): | |
16 | + return f"{t:%H:%M:%S} {channel} <{speaker}> {entry}" | |
46 | 17 | |
47 | 18 | def parsePygmalion(response, speakers): |
48 | 19 | "Pygmalion-specific parser to work around fine-tuned quirks." |