source: code/trunk/contrib/znc-import.go@ 771

Last change on this file since 771 was 696, checked in by contact, 4 years ago

contrib/znc-import: use background context

File size: 10.2 KB
RevLine 
[357]1package main
2
import (
	"bufio"
	"context"
	"flag"
	"fmt"
	"io"
	"log"
	"net"
	"net/url"
	"os"
	"strconv"
	"strings"
	"unicode"

	"git.sr.ht/~emersion/soju"
	"git.sr.ht/~emersion/soju/config"
)
18
// usage is the help text that the flag.Usage override in init prints
// verbatim. -user and -network may be repeated; each occurrence is collected
// into a set by stringSetFlag.
19const usage = `usage: znc-import [options...] <znc config path>
20
21Imports configuration from a ZNC file. Users and networks are merged if they
22already exist in the soju database. ZNC settings overwrite existing soju
23settings.
24
25Options:
26
27 -help Show this help message
28 -config <path> Path to soju config file
29 -user <username> Limit import to username (may be specified multiple times)
30 -network <name> Limit import to network (may be specified multiple times)
31`
32
33func init() {
34 flag.Usage = func() {
35 fmt.Fprintf(flag.CommandLine.Output(), usage)
36 }
37}
38
39func main() {
40 var configPath string
41 users := make(map[string]bool)
42 networks := make(map[string]bool)
43 flag.StringVar(&configPath, "config", "", "path to configuration file")
44 flag.Var((*stringSetFlag)(&users), "user", "")
45 flag.Var((*stringSetFlag)(&networks), "network", "")
46 flag.Parse()
47
48 zncPath := flag.Arg(0)
49 if zncPath == "" {
50 flag.Usage()
51 os.Exit(1)
52 }
53
54 var cfg *config.Server
55 if configPath != "" {
56 var err error
57 cfg, err = config.Load(configPath)
58 if err != nil {
59 log.Fatalf("failed to load config file: %v", err)
60 }
61 } else {
62 cfg = config.Defaults()
63 }
64
[696]65 ctx := context.Background()
66
[620]67 db, err := soju.OpenDB(cfg.SQLDriver, cfg.SQLSource)
[357]68 if err != nil {
69 log.Fatalf("failed to open database: %v", err)
70 }
71 defer db.Close()
72
73 f, err := os.Open(zncPath)
74 if err != nil {
75 log.Fatalf("failed to open ZNC configuration file: %v", err)
76 }
77 defer f.Close()
78
79 zp := zncParser{bufio.NewReader(f), 1}
80 root, err := zp.sectionBody("", "")
81 if err != nil {
82 log.Fatalf("failed to parse %q: line %v: %v", zncPath, zp.line, err)
83 }
84
[696]85 l, err := db.ListUsers(ctx)
[357]86 if err != nil {
87 log.Fatalf("failed to list users in DB: %v", err)
88 }
89 existingUsers := make(map[string]*soju.User, len(l))
90 for i, u := range l {
91 existingUsers[u.Username] = &l[i]
92 }
93
94 usersCreated := 0
95 usersImported := 0
96 networksImported := 0
97 channelsImported := 0
98 root.ForEach("User", func(section *zncSection) {
99 username := section.Name
100 if len(users) > 0 && !users[username] {
101 return
102 }
103 usersImported++
104
105 u, ok := existingUsers[username]
106 if ok {
107 log.Printf("user %q: updating existing user", username)
108 } else {
109 // "!!" is an invalid crypt format, thus disables password auth
110 u = &soju.User{Username: username, Password: "!!"}
111 usersCreated++
112 log.Printf("user %q: creating new user", username)
113 }
114
115 u.Admin = section.Values.Get("Admin") == "true"
116
[696]117 if err := db.StoreUser(ctx, u); err != nil {
[357]118 log.Fatalf("failed to store user %q: %v", username, err)
119 }
[421]120 userID := u.ID
[357]121
[696]122 l, err := db.ListNetworks(ctx, userID)
[357]123 if err != nil {
124 log.Fatalf("failed to list networks for user %q: %v", username, err)
125 }
126 existingNetworks := make(map[string]*soju.Network, len(l))
127 for i, n := range l {
128 existingNetworks[n.GetName()] = &l[i]
129 }
130
131 nick := section.Values.Get("Nick")
132 realname := section.Values.Get("RealName")
133 ident := section.Values.Get("Ident")
134
135 section.ForEach("Network", func(section *zncSection) {
136 netName := section.Name
137 if len(networks) > 0 && !networks[netName] {
138 return
139 }
140 networksImported++
141
142 logPrefix := fmt.Sprintf("user %q: network %q: ", username, netName)
143 logger := log.New(os.Stderr, logPrefix, log.LstdFlags|log.Lmsgprefix)
144
145 netNick := section.Values.Get("Nick")
146 if netNick == "" {
147 netNick = nick
148 }
149 netRealname := section.Values.Get("RealName")
150 if netRealname == "" {
151 netRealname = realname
152 }
153 netIdent := section.Values.Get("Ident")
154 if netIdent == "" {
155 netIdent = ident
156 }
157
158 for _, name := range section.Values["LoadModule"] {
159 switch name {
160 case "sasl":
161 logger.Printf("warning: SASL credentials not imported")
162 case "nickserv":
163 logger.Printf("warning: NickServ credentials not imported")
164 case "perform":
165 logger.Printf("warning: \"perform\" plugin commands not imported")
166 }
167 }
168
169 u, pass, err := importNetworkServer(section.Values.Get("Server"))
170 if err != nil {
171 logger.Fatalf("failed to import server %q: %v", section.Values.Get("Server"), err)
172 }
173
174 n, ok := existingNetworks[netName]
175 if ok {
176 logger.Printf("updating existing network")
177 } else {
178 n = &soju.Network{Name: netName}
179 logger.Printf("creating new network")
180 }
181
182 n.Addr = u.String()
183 n.Nick = netNick
184 n.Username = netIdent
185 n.Realname = netRealname
186 n.Pass = pass
[542]187 n.Enabled = section.Values.Get("IRCConnectEnabled") != "false"
[357]188
[696]189 if err := db.StoreNetwork(ctx, userID, n); err != nil {
[357]190 logger.Fatalf("failed to store network: %v", err)
191 }
192
[696]193 l, err := db.ListChannels(ctx, n.ID)
[357]194 if err != nil {
195 logger.Fatalf("failed to list channels: %v", err)
196 }
197 existingChannels := make(map[string]*soju.Channel, len(l))
198 for i, ch := range l {
199 existingChannels[ch.Name] = &l[i]
200 }
201
202 section.ForEach("Chan", func(section *zncSection) {
203 chName := section.Name
204
205 if section.Values.Get("Disabled") == "true" {
206 logger.Printf("skipping import of disabled channel %q", chName)
207 return
208 }
209
210 channelsImported++
211
212 ch, ok := existingChannels[chName]
213 if ok {
214 logger.Printf("channel %q: updating existing channel", chName)
215 } else {
216 ch = &soju.Channel{Name: chName}
217 logger.Printf("channel %q: creating new channel", chName)
218 }
219
220 ch.Key = section.Values.Get("Key")
221 ch.Detached = section.Values.Get("Detached") == "true"
222
[696]223 if err := db.StoreChannel(ctx, n.ID, ch); err != nil {
[357]224 logger.Printf("channel %q: failed to store channel: %v", chName, err)
225 }
226 })
227 })
228 })
229
230 if err := db.Close(); err != nil {
231 log.Printf("failed to close database: %v", err)
232 }
233
234 if usersCreated > 0 {
235 log.Printf("warning: user passwords haven't been imported, please set them with `sojuctl change-password <username>`")
236 }
237
238 log.Printf("imported %v users, %v networks and %v channels", usersImported, networksImported, channelsImported)
239}
240
// importNetworkServer converts a ZNC "Server" value into a soju upstream URL
// and server password.
//
// The value has the form "host [+]port [password]". A '+' before the port
// selects TLS ("ircs"); otherwise the plain-text "irc+insecure" scheme is
// used. Everything after the port is the password, so passwords containing
// spaces are kept intact. IPv6 hosts, bare or already bracketed, are
// bracketed in the resulting URL.
//
// An error is returned when the host or port is missing or the port is not a
// number in the 0-65535 range.
func importNetworkServer(s string) (u *url.URL, pass string, err error) {
	rest := strings.TrimSpace(s)
	// nextField pops the next whitespace-delimited field off rest.
	nextField := func() string {
		i := strings.IndexFunc(rest, unicode.IsSpace)
		if i < 0 {
			field := rest
			rest = ""
			return field
		}
		field := rest[:i]
		rest = strings.TrimLeftFunc(rest[i:], unicode.IsSpace)
		return field
	}

	host := nextField()
	port := nextField()
	if host == "" || port == "" {
		return nil, "", fmt.Errorf("expected space-separated host and port")
	}
	pass = rest

	scheme := "irc+insecure"
	if strings.HasPrefix(port, "+") {
		port = port[1:]
		scheme = "ircs"
	}
	if _, convErr := strconv.ParseUint(port, 10, 16); convErr != nil {
		return nil, "", fmt.Errorf("invalid port %q", port)
	}

	// JoinHostPort adds brackets itself, so strip any the config already has.
	if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
		host = host[1 : len(host)-1]
	}

	u = &url.URL{
		Scheme: scheme,
		Host:   net.JoinHostPort(host, port),
	}
	return u, pass, nil
}
265
266type zncSection struct {
267 Type string
268 Name string
269 Values zncValues
270 Children []zncSection
271}
272
273func (s *zncSection) ForEach(typ string, f func(*zncSection)) {
274 for _, section := range s.Children {
275 if section.Type == typ {
276 f(&section)
277 }
278 }
279}
280
// zncValues maps a configuration key to every value assigned to it within one
// section, in file order. Keys such as LoadModule may appear several times.
type zncValues map[string][]string

// Get returns the first value recorded for k, or "" when k was never set.
func (zv zncValues) Get(k string) string {
	if vals := zv[k]; len(vals) > 0 {
		return vals[0]
	}
	return ""
}
289
290type zncParser struct {
291 br *bufio.Reader
292 line int
293}
294
295func (zp *zncParser) readByte() (byte, error) {
296 b, err := zp.br.ReadByte()
297 if b == '\n' {
298 zp.line++
299 }
300 return b, err
301}
302
303func (zp *zncParser) readRune() (rune, int, error) {
304 r, n, err := zp.br.ReadRune()
305 if r == '\n' {
306 zp.line++
307 }
308 return r, n, err
309}
310
311func (zp *zncParser) sectionBody(typ, name string) (*zncSection, error) {
312 section := &zncSection{Type: typ, Name: name, Values: make(zncValues)}
313
314Loop:
315 for {
316 if err := zp.skipSpace(); err != nil {
317 return nil, err
318 }
319
320 b, err := zp.br.Peek(2)
321 if err == io.EOF {
322 break
323 } else if err != nil {
324 return nil, err
325 }
326
327 switch b[0] {
328 case '<':
329 if b[1] == '/' {
330 break Loop
331 } else {
332 childType, childName, err := zp.sectionHeader()
333 if err != nil {
334 return nil, err
335 }
336 child, err := zp.sectionBody(childType, childName)
337 if err != nil {
338 return nil, err
339 }
340 if footerType, err := zp.sectionFooter(); err != nil {
341 return nil, err
342 } else if footerType != childType {
343 return nil, fmt.Errorf("invalid section footer: expected type %q, got %q", childType, footerType)
344 }
345 section.Children = append(section.Children, *child)
346 }
347 case '/':
348 if b[1] == '/' {
349 if err := zp.skipComment(); err != nil {
350 return nil, err
351 }
352 break
353 }
354 fallthrough
355 default:
356 k, v, err := zp.keyValuePair()
357 if err != nil {
358 return nil, err
359 }
360 section.Values[k] = append(section.Values[k], v)
361 }
362 }
363
364 return section, nil
365}
366
367func (zp *zncParser) skipSpace() error {
368 for {
369 r, _, err := zp.readRune()
370 if err == io.EOF {
371 return nil
372 } else if err != nil {
373 return err
374 }
375
376 if !unicode.IsSpace(r) {
377 zp.br.UnreadRune()
378 return nil
379 }
380 }
381}
382
383func (zp *zncParser) skipComment() error {
384 if err := zp.expectRune('/'); err != nil {
385 return err
386 }
387 if err := zp.expectRune('/'); err != nil {
388 return err
389 }
390
391 for {
392 b, err := zp.readByte()
393 if err == io.EOF {
394 return nil
395 } else if err != nil {
396 return err
397 }
398
399 if b == '\n' {
400 return nil
401 }
402 }
403}
404
405func (zp *zncParser) sectionHeader() (string, string, error) {
406 if err := zp.expectRune('<'); err != nil {
407 return "", "", err
408 }
409 typ, err := zp.readWord(' ')
410 if err != nil {
411 return "", "", err
412 }
413 name, err := zp.readWord('>')
414 return typ, name, err
415}
416
417func (zp *zncParser) sectionFooter() (string, error) {
418 if err := zp.expectRune('<'); err != nil {
419 return "", err
420 }
421 if err := zp.expectRune('/'); err != nil {
422 return "", err
423 }
424 return zp.readWord('>')
425}
426
427func (zp *zncParser) keyValuePair() (string, string, error) {
428 k, err := zp.readWord('=')
429 if err != nil {
430 return "", "", err
431 }
432 v, err := zp.readWord('\n')
433 return strings.TrimSpace(k), strings.TrimSpace(v), err
434}
435
436func (zp *zncParser) expectRune(expected rune) error {
437 r, _, err := zp.readRune()
438 if err != nil {
439 return err
440 } else if r != expected {
441 return fmt.Errorf("expected %q, got %q", expected, r)
442 }
443 return nil
444}
445
446func (zp *zncParser) readWord(delim byte) (string, error) {
447 var sb strings.Builder
448 for {
449 b, err := zp.readByte()
450 if err != nil {
451 return "", err
452 }
453
454 if b == delim {
455 return sb.String(), nil
456 }
457 if b == '\n' {
458 return "", fmt.Errorf("expected %q before newline", delim)
459 }
460
461 sb.WriteByte(b)
462 }
463}
464
// stringSetFlag is a flag.Value that collects every occurrence of a
// repeatable string flag into a set.
type stringSetFlag map[string]bool

var _ flag.Value = (*stringSetFlag)(nil)

// String formats the set. The flag.Value contract allows String to be called
// on a zero value, including a nil pointer, so that case is handled
// explicitly.
func (v *stringSetFlag) String() string {
	if v == nil {
		return ""
	}
	return fmt.Sprint(map[string]bool(*v))
}

// Set adds s to the set. It is called once per occurrence of the flag and
// allocates the map on first use, so the zero value is ready to use.
func (v *stringSetFlag) Set(s string) error {
	if *v == nil {
		*v = make(stringSetFlag)
	}
	(*v)[s] = true
	return nil
}
Note: See TracBrowser for help on using the repository browser.