source: code/trunk/contrib/znc-import.go@ 507

Last change on this file since 507 was 421, checked in by contact, 5 years ago

Switch DB API to user IDs

This commit changes the Network schema to use user IDs instead of
usernames. While at it, a new UNIQUE(user, name) constraint ensures
there is no conflict with custom network names.

Closes: https://todo.sr.ht/~emersion/soju/86
References: https://todo.sr.ht/~emersion/soju/29

File size: 10.1 KB
Line 
1package main
2
import (
	"bufio"
	"flag"
	"fmt"
	"io"
	"log"
	"net"
	"net/url"
	"os"
	"strings"
	"unicode"

	"git.sr.ht/~emersion/soju"
	"git.sr.ht/~emersion/soju/config"
)
17
// usage is the help text printed by flag.Usage (installed in init). The
// option list mirrors the flags registered in main; descriptions are
// column-aligned so the output stays readable.
const usage = `usage: znc-import [options...] <znc config path>

Imports configuration from a ZNC file. Users and networks are merged if they
already exist in the soju database. ZNC settings overwrite existing soju
settings.

Options:

  -help             Show this help message
  -config <path>    Path to soju config file
  -user <username>  Limit import to username (may be specified multiple times)
  -network <name>   Limit import to network (may be specified multiple times)
`
31
32func init() {
33 flag.Usage = func() {
34 fmt.Fprintf(flag.CommandLine.Output(), usage)
35 }
36}
37
38func main() {
39 var configPath string
40 users := make(map[string]bool)
41 networks := make(map[string]bool)
42 flag.StringVar(&configPath, "config", "", "path to configuration file")
43 flag.Var((*stringSetFlag)(&users), "user", "")
44 flag.Var((*stringSetFlag)(&networks), "network", "")
45 flag.Parse()
46
47 zncPath := flag.Arg(0)
48 if zncPath == "" {
49 flag.Usage()
50 os.Exit(1)
51 }
52
53 var cfg *config.Server
54 if configPath != "" {
55 var err error
56 cfg, err = config.Load(configPath)
57 if err != nil {
58 log.Fatalf("failed to load config file: %v", err)
59 }
60 } else {
61 cfg = config.Defaults()
62 }
63
64 db, err := soju.OpenSQLDB(cfg.SQLDriver, cfg.SQLSource)
65 if err != nil {
66 log.Fatalf("failed to open database: %v", err)
67 }
68 defer db.Close()
69
70 f, err := os.Open(zncPath)
71 if err != nil {
72 log.Fatalf("failed to open ZNC configuration file: %v", err)
73 }
74 defer f.Close()
75
76 zp := zncParser{bufio.NewReader(f), 1}
77 root, err := zp.sectionBody("", "")
78 if err != nil {
79 log.Fatalf("failed to parse %q: line %v: %v", zncPath, zp.line, err)
80 }
81
82 l, err := db.ListUsers()
83 if err != nil {
84 log.Fatalf("failed to list users in DB: %v", err)
85 }
86 existingUsers := make(map[string]*soju.User, len(l))
87 for i, u := range l {
88 existingUsers[u.Username] = &l[i]
89 }
90
91 usersCreated := 0
92 usersImported := 0
93 networksImported := 0
94 channelsImported := 0
95 root.ForEach("User", func(section *zncSection) {
96 username := section.Name
97 if len(users) > 0 && !users[username] {
98 return
99 }
100 usersImported++
101
102 u, ok := existingUsers[username]
103 if ok {
104 log.Printf("user %q: updating existing user", username)
105 } else {
106 // "!!" is an invalid crypt format, thus disables password auth
107 u = &soju.User{Username: username, Password: "!!"}
108 usersCreated++
109 log.Printf("user %q: creating new user", username)
110 }
111
112 u.Admin = section.Values.Get("Admin") == "true"
113
114 if err := db.StoreUser(u); err != nil {
115 log.Fatalf("failed to store user %q: %v", username, err)
116 }
117 userID := u.ID
118
119 l, err := db.ListNetworks(userID)
120 if err != nil {
121 log.Fatalf("failed to list networks for user %q: %v", username, err)
122 }
123 existingNetworks := make(map[string]*soju.Network, len(l))
124 for i, n := range l {
125 existingNetworks[n.GetName()] = &l[i]
126 }
127
128 nick := section.Values.Get("Nick")
129 realname := section.Values.Get("RealName")
130 ident := section.Values.Get("Ident")
131
132 section.ForEach("Network", func(section *zncSection) {
133 netName := section.Name
134 if len(networks) > 0 && !networks[netName] {
135 return
136 }
137 networksImported++
138
139 logPrefix := fmt.Sprintf("user %q: network %q: ", username, netName)
140 logger := log.New(os.Stderr, logPrefix, log.LstdFlags|log.Lmsgprefix)
141
142 netNick := section.Values.Get("Nick")
143 if netNick == "" {
144 netNick = nick
145 }
146 netRealname := section.Values.Get("RealName")
147 if netRealname == "" {
148 netRealname = realname
149 }
150 netIdent := section.Values.Get("Ident")
151 if netIdent == "" {
152 netIdent = ident
153 }
154
155 for _, name := range section.Values["LoadModule"] {
156 switch name {
157 case "sasl":
158 logger.Printf("warning: SASL credentials not imported")
159 case "nickserv":
160 logger.Printf("warning: NickServ credentials not imported")
161 case "perform":
162 logger.Printf("warning: \"perform\" plugin commands not imported")
163 }
164 }
165
166 u, pass, err := importNetworkServer(section.Values.Get("Server"))
167 if err != nil {
168 logger.Fatalf("failed to import server %q: %v", section.Values.Get("Server"), err)
169 }
170
171 n, ok := existingNetworks[netName]
172 if ok {
173 logger.Printf("updating existing network")
174 } else {
175 n = &soju.Network{Name: netName}
176 logger.Printf("creating new network")
177 }
178
179 n.Addr = u.String()
180 n.Nick = netNick
181 n.Username = netIdent
182 n.Realname = netRealname
183 n.Pass = pass
184
185 if err := db.StoreNetwork(userID, n); err != nil {
186 logger.Fatalf("failed to store network: %v", err)
187 }
188
189 l, err := db.ListChannels(n.ID)
190 if err != nil {
191 logger.Fatalf("failed to list channels: %v", err)
192 }
193 existingChannels := make(map[string]*soju.Channel, len(l))
194 for i, ch := range l {
195 existingChannels[ch.Name] = &l[i]
196 }
197
198 section.ForEach("Chan", func(section *zncSection) {
199 chName := section.Name
200
201 if section.Values.Get("Disabled") == "true" {
202 logger.Printf("skipping import of disabled channel %q", chName)
203 return
204 }
205
206 channelsImported++
207
208 ch, ok := existingChannels[chName]
209 if ok {
210 logger.Printf("channel %q: updating existing channel", chName)
211 } else {
212 ch = &soju.Channel{Name: chName}
213 logger.Printf("channel %q: creating new channel", chName)
214 }
215
216 ch.Key = section.Values.Get("Key")
217 ch.Detached = section.Values.Get("Detached") == "true"
218
219 if err := db.StoreChannel(n.ID, ch); err != nil {
220 logger.Printf("channel %q: failed to store channel: %v", chName, err)
221 }
222 })
223 })
224 })
225
226 if err := db.Close(); err != nil {
227 log.Printf("failed to close database: %v", err)
228 }
229
230 if usersCreated > 0 {
231 log.Printf("warning: user passwords haven't been imported, please set them with `sojuctl change-password <username>`")
232 }
233
234 log.Printf("imported %v users, %v networks and %v channels", usersImported, networksImported, channelsImported)
235}
236
// importNetworkServer converts a ZNC "Server" value into a soju network
// address and server password.
//
// The value has the form "<host> [+]<port> [password]". A "+" before the
// port selects TLS (the "ircs" scheme); otherwise the plaintext
// "irc+insecure" scheme is used. Everything after the port, trimmed, is
// returned as the password, so a password containing spaces is kept intact.
// An error is returned when the host or the port is missing.
func importNetworkServer(s string) (u *url.URL, pass string, err error) {
	host, rest := cutField(s)
	port, rest := cutField(rest)
	if host == "" || port == "" {
		return nil, "", fmt.Errorf("expected space-separated host and port")
	}

	scheme := "irc+insecure"
	if strings.HasPrefix(port, "+") {
		port = strings.TrimPrefix(port, "+")
		scheme = "ircs"
	}

	u = &url.URL{
		Scheme: scheme,
		// JoinHostPort brackets IPv6 literals ("[::1]:6697"); plain
		// concatenation would yield the ambiguous "::1:6697".
		Host: net.JoinHostPort(host, port),
	}
	return u, strings.TrimSpace(rest), nil
}

// cutField splits s into its first white-space-delimited field and the
// remainder that follows it (leading white space of s is skipped).
func cutField(s string) (field, rest string) {
	s = strings.TrimLeftFunc(s, unicode.IsSpace)
	if i := strings.IndexFunc(s, unicode.IsSpace); i >= 0 {
		return s[:i], s[i:]
	}
	return s, ""
}
261
// zncSection is a parsed ZNC configuration section such as
// "<User alice> ... </User>". The file root is a section with an empty Type
// and Name.
type zncSection struct {
	Type     string       // section kind, e.g. "User", "Network", "Chan"
	Name     string       // section argument, e.g. the username
	Values   zncValues    // "Key = Value" pairs directly inside the section
	Children []zncSection // nested sections, in file order
}

// ForEach calls f for each direct child section whose Type is typ, in file
// order. The pointer passed to f refers to the child stored in s.Children
// (not a copy), so changes made through it are kept.
func (s *zncSection) ForEach(typ string, f func(*zncSection)) {
	for i := range s.Children {
		if child := &s.Children[i]; child.Type == typ {
			f(child)
		}
	}
}

// zncValues maps a configuration key to every value given for it, in file
// order; keys such as LoadModule may appear several times.
type zncValues map[string][]string

// Get returns the first value recorded for key k, or "" if there is none.
func (zv zncValues) Get(k string) string {
	if vs := zv[k]; len(vs) > 0 {
		return vs[0]
	}
	return ""
}
285
// zncParser is a recursive-descent parser for the ZNC configuration format.
// It reads from br and tracks the current line for error reporting.
type zncParser struct {
	br   *bufio.Reader
	line int // current line number; incremented by readByte/readRune on each '\n'
}
290
291func (zp *zncParser) readByte() (byte, error) {
292 b, err := zp.br.ReadByte()
293 if b == '\n' {
294 zp.line++
295 }
296 return b, err
297}
298
299func (zp *zncParser) readRune() (rune, int, error) {
300 r, n, err := zp.br.ReadRune()
301 if r == '\n' {
302 zp.line++
303 }
304 return r, n, err
305}
306
307func (zp *zncParser) sectionBody(typ, name string) (*zncSection, error) {
308 section := &zncSection{Type: typ, Name: name, Values: make(zncValues)}
309
310Loop:
311 for {
312 if err := zp.skipSpace(); err != nil {
313 return nil, err
314 }
315
316 b, err := zp.br.Peek(2)
317 if err == io.EOF {
318 break
319 } else if err != nil {
320 return nil, err
321 }
322
323 switch b[0] {
324 case '<':
325 if b[1] == '/' {
326 break Loop
327 } else {
328 childType, childName, err := zp.sectionHeader()
329 if err != nil {
330 return nil, err
331 }
332 child, err := zp.sectionBody(childType, childName)
333 if err != nil {
334 return nil, err
335 }
336 if footerType, err := zp.sectionFooter(); err != nil {
337 return nil, err
338 } else if footerType != childType {
339 return nil, fmt.Errorf("invalid section footer: expected type %q, got %q", childType, footerType)
340 }
341 section.Children = append(section.Children, *child)
342 }
343 case '/':
344 if b[1] == '/' {
345 if err := zp.skipComment(); err != nil {
346 return nil, err
347 }
348 break
349 }
350 fallthrough
351 default:
352 k, v, err := zp.keyValuePair()
353 if err != nil {
354 return nil, err
355 }
356 section.Values[k] = append(section.Values[k], v)
357 }
358 }
359
360 return section, nil
361}
362
363func (zp *zncParser) skipSpace() error {
364 for {
365 r, _, err := zp.readRune()
366 if err == io.EOF {
367 return nil
368 } else if err != nil {
369 return err
370 }
371
372 if !unicode.IsSpace(r) {
373 zp.br.UnreadRune()
374 return nil
375 }
376 }
377}
378
379func (zp *zncParser) skipComment() error {
380 if err := zp.expectRune('/'); err != nil {
381 return err
382 }
383 if err := zp.expectRune('/'); err != nil {
384 return err
385 }
386
387 for {
388 b, err := zp.readByte()
389 if err == io.EOF {
390 return nil
391 } else if err != nil {
392 return err
393 }
394
395 if b == '\n' {
396 return nil
397 }
398 }
399}
400
401func (zp *zncParser) sectionHeader() (string, string, error) {
402 if err := zp.expectRune('<'); err != nil {
403 return "", "", err
404 }
405 typ, err := zp.readWord(' ')
406 if err != nil {
407 return "", "", err
408 }
409 name, err := zp.readWord('>')
410 return typ, name, err
411}
412
413func (zp *zncParser) sectionFooter() (string, error) {
414 if err := zp.expectRune('<'); err != nil {
415 return "", err
416 }
417 if err := zp.expectRune('/'); err != nil {
418 return "", err
419 }
420 return zp.readWord('>')
421}
422
423func (zp *zncParser) keyValuePair() (string, string, error) {
424 k, err := zp.readWord('=')
425 if err != nil {
426 return "", "", err
427 }
428 v, err := zp.readWord('\n')
429 return strings.TrimSpace(k), strings.TrimSpace(v), err
430}
431
432func (zp *zncParser) expectRune(expected rune) error {
433 r, _, err := zp.readRune()
434 if err != nil {
435 return err
436 } else if r != expected {
437 return fmt.Errorf("expected %q, got %q", expected, r)
438 }
439 return nil
440}
441
442func (zp *zncParser) readWord(delim byte) (string, error) {
443 var sb strings.Builder
444 for {
445 b, err := zp.readByte()
446 if err != nil {
447 return "", err
448 }
449
450 if b == delim {
451 return sb.String(), nil
452 }
453 if b == '\n' {
454 return "", fmt.Errorf("expected %q before newline", delim)
455 }
456
457 sb.WriteByte(b)
458 }
459}
460
// stringSetFlag is a flag.Value that collects every occurrence of a
// repeatable string flag into a set.
type stringSetFlag map[string]bool

var _ flag.Value = (*stringSetFlag)(nil)

// String implements flag.Value; it formats the set with fmt.Sprint.
func (v *stringSetFlag) String() string {
	return fmt.Sprint(map[string]bool(*v))
}

// Set implements flag.Value by adding s to the set. The map is allocated on
// first use, so a zero stringSetFlag is ready to use.
func (v *stringSetFlag) Set(s string) error {
	if *v == nil {
		*v = make(stringSetFlag)
	}
	(*v)[s] = true
	return nil
}
Note: See TracBrowser for help on using the repository browser.