source: code/trunk/contrib/znc-import.go@ 692

Last change on this file since 692 was 652, checked in by contact, 4 years ago

Add context args to Database interface

This is a mechanical change, which just lifts the context.TODO()
calls from inside the DB implementations up to the callers.

Future work involves properly wiring up the contexts when it makes
sense.

File size: 10.2 KB
RevLine 
[357]1package main
2
import (
	"bufio"
	"context"
	"flag"
	"fmt"
	"io"
	"log"
	"net"
	"net/url"
	"os"
	"strings"
	"unicode"

	"git.sr.ht/~emersion/soju"
	"git.sr.ht/~emersion/soju/config"
)
18
// usage is the help text printed by flag.Usage (e.g. for -help or when the
// ZNC config path argument is missing).
const usage = `usage: znc-import [options...] <znc config path>

Imports configuration from a ZNC file. Users and networks are merged if they
already exist in the soju database. ZNC settings overwrite existing soju
settings.

Options:

  -help             Show this help message
  -config <path>    Path to soju config file
  -user <username>  Limit import to username (may be specified multiple times)
  -network <name>   Limit import to network (may be specified multiple times)
`

func init() {
	// Replace the flag package's generated help with the hand-written text.
	flag.Usage = func() {
		// usage is plain text, not a format string: Fprint guarantees that a
		// "%" in it is printed verbatim instead of being parsed as a verb.
		fmt.Fprint(flag.CommandLine.Output(), usage)
	}
}
38
39func main() {
40 var configPath string
41 users := make(map[string]bool)
42 networks := make(map[string]bool)
43 flag.StringVar(&configPath, "config", "", "path to configuration file")
44 flag.Var((*stringSetFlag)(&users), "user", "")
45 flag.Var((*stringSetFlag)(&networks), "network", "")
46 flag.Parse()
47
48 zncPath := flag.Arg(0)
49 if zncPath == "" {
50 flag.Usage()
51 os.Exit(1)
52 }
53
54 var cfg *config.Server
55 if configPath != "" {
56 var err error
57 cfg, err = config.Load(configPath)
58 if err != nil {
59 log.Fatalf("failed to load config file: %v", err)
60 }
61 } else {
62 cfg = config.Defaults()
63 }
64
[620]65 db, err := soju.OpenDB(cfg.SQLDriver, cfg.SQLSource)
[357]66 if err != nil {
67 log.Fatalf("failed to open database: %v", err)
68 }
69 defer db.Close()
70
71 f, err := os.Open(zncPath)
72 if err != nil {
73 log.Fatalf("failed to open ZNC configuration file: %v", err)
74 }
75 defer f.Close()
76
77 zp := zncParser{bufio.NewReader(f), 1}
78 root, err := zp.sectionBody("", "")
79 if err != nil {
80 log.Fatalf("failed to parse %q: line %v: %v", zncPath, zp.line, err)
81 }
82
[652]83 l, err := db.ListUsers(context.TODO())
[357]84 if err != nil {
85 log.Fatalf("failed to list users in DB: %v", err)
86 }
87 existingUsers := make(map[string]*soju.User, len(l))
88 for i, u := range l {
89 existingUsers[u.Username] = &l[i]
90 }
91
92 usersCreated := 0
93 usersImported := 0
94 networksImported := 0
95 channelsImported := 0
96 root.ForEach("User", func(section *zncSection) {
97 username := section.Name
98 if len(users) > 0 && !users[username] {
99 return
100 }
101 usersImported++
102
103 u, ok := existingUsers[username]
104 if ok {
105 log.Printf("user %q: updating existing user", username)
106 } else {
107 // "!!" is an invalid crypt format, thus disables password auth
108 u = &soju.User{Username: username, Password: "!!"}
109 usersCreated++
110 log.Printf("user %q: creating new user", username)
111 }
112
113 u.Admin = section.Values.Get("Admin") == "true"
114
[652]115 if err := db.StoreUser(context.TODO(), u); err != nil {
[357]116 log.Fatalf("failed to store user %q: %v", username, err)
117 }
[421]118 userID := u.ID
[357]119
[652]120 l, err := db.ListNetworks(context.TODO(), userID)
[357]121 if err != nil {
122 log.Fatalf("failed to list networks for user %q: %v", username, err)
123 }
124 existingNetworks := make(map[string]*soju.Network, len(l))
125 for i, n := range l {
126 existingNetworks[n.GetName()] = &l[i]
127 }
128
129 nick := section.Values.Get("Nick")
130 realname := section.Values.Get("RealName")
131 ident := section.Values.Get("Ident")
132
133 section.ForEach("Network", func(section *zncSection) {
134 netName := section.Name
135 if len(networks) > 0 && !networks[netName] {
136 return
137 }
138 networksImported++
139
140 logPrefix := fmt.Sprintf("user %q: network %q: ", username, netName)
141 logger := log.New(os.Stderr, logPrefix, log.LstdFlags|log.Lmsgprefix)
142
143 netNick := section.Values.Get("Nick")
144 if netNick == "" {
145 netNick = nick
146 }
147 netRealname := section.Values.Get("RealName")
148 if netRealname == "" {
149 netRealname = realname
150 }
151 netIdent := section.Values.Get("Ident")
152 if netIdent == "" {
153 netIdent = ident
154 }
155
156 for _, name := range section.Values["LoadModule"] {
157 switch name {
158 case "sasl":
159 logger.Printf("warning: SASL credentials not imported")
160 case "nickserv":
161 logger.Printf("warning: NickServ credentials not imported")
162 case "perform":
163 logger.Printf("warning: \"perform\" plugin commands not imported")
164 }
165 }
166
167 u, pass, err := importNetworkServer(section.Values.Get("Server"))
168 if err != nil {
169 logger.Fatalf("failed to import server %q: %v", section.Values.Get("Server"), err)
170 }
171
172 n, ok := existingNetworks[netName]
173 if ok {
174 logger.Printf("updating existing network")
175 } else {
176 n = &soju.Network{Name: netName}
177 logger.Printf("creating new network")
178 }
179
180 n.Addr = u.String()
181 n.Nick = netNick
182 n.Username = netIdent
183 n.Realname = netRealname
184 n.Pass = pass
[542]185 n.Enabled = section.Values.Get("IRCConnectEnabled") != "false"
[357]186
[652]187 if err := db.StoreNetwork(context.TODO(), userID, n); err != nil {
[357]188 logger.Fatalf("failed to store network: %v", err)
189 }
190
[652]191 l, err := db.ListChannels(context.TODO(), n.ID)
[357]192 if err != nil {
193 logger.Fatalf("failed to list channels: %v", err)
194 }
195 existingChannels := make(map[string]*soju.Channel, len(l))
196 for i, ch := range l {
197 existingChannels[ch.Name] = &l[i]
198 }
199
200 section.ForEach("Chan", func(section *zncSection) {
201 chName := section.Name
202
203 if section.Values.Get("Disabled") == "true" {
204 logger.Printf("skipping import of disabled channel %q", chName)
205 return
206 }
207
208 channelsImported++
209
210 ch, ok := existingChannels[chName]
211 if ok {
212 logger.Printf("channel %q: updating existing channel", chName)
213 } else {
214 ch = &soju.Channel{Name: chName}
215 logger.Printf("channel %q: creating new channel", chName)
216 }
217
218 ch.Key = section.Values.Get("Key")
219 ch.Detached = section.Values.Get("Detached") == "true"
220
[652]221 if err := db.StoreChannel(context.TODO(), n.ID, ch); err != nil {
[357]222 logger.Printf("channel %q: failed to store channel: %v", chName, err)
223 }
224 })
225 })
226 })
227
228 if err := db.Close(); err != nil {
229 log.Printf("failed to close database: %v", err)
230 }
231
232 if usersCreated > 0 {
233 log.Printf("warning: user passwords haven't been imported, please set them with `sojuctl change-password <username>`")
234 }
235
236 log.Printf("imported %v users, %v networks and %v channels", usersImported, networksImported, channelsImported)
237}
238
// importNetworkServer converts a ZNC "Server" value into a soju server URL
// and an optional server password.
//
// The value has the form "<host> [+]<port> [password]":
//   - a "+" before the port selects TLS (scheme "ircs"); otherwise the
//     scheme is "irc+insecure";
//   - the password is the whole remainder of the line after the port, so it
//     may contain spaces (ZNC reads it the same way);
//   - IPv6 literal hosts, bracketed or not, are bracketed in the URL.
//
// An error is returned when the host or the port is missing.
func importNetworkServer(s string) (u *url.URL, pass string, err error) {
	host, rest := cutServerField(s)
	port, rest := cutServerField(rest)
	if host == "" || port == "" {
		return nil, "", fmt.Errorf("expected space-separated host and port")
	}
	pass = strings.TrimSpace(rest)

	scheme := "irc+insecure"
	if strings.HasPrefix(port, "+") {
		port = port[1:]
		scheme = "ircs"
	}

	// net.JoinHostPort adds the brackets an IPv6 literal needs in a URL
	// host; strip any the config already has so they aren't doubled.
	if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
		host = host[1 : len(host)-1]
	}

	u = &url.URL{
		Scheme: scheme,
		Host:   net.JoinHostPort(host, port),
	}
	return u, pass, nil
}

// cutServerField skips leading whitespace in s and splits off its first
// whitespace-delimited field, returning that field and everything after it
// (leading whitespace of the remainder included).
func cutServerField(s string) (field, rest string) {
	s = strings.TrimLeftFunc(s, unicode.IsSpace)
	if i := strings.IndexFunc(s, unicode.IsSpace); i >= 0 {
		return s[:i], s[i:]
	}
	return s, ""
}
263
// zncSection is one parsed section of a ZNC configuration file: either the
// implicit root section (empty Type and Name) or a "<Type Name> ... </Type>"
// block, with its key/value pairs and nested sections in file order.
type zncSection struct {
	Type     string
	Name     string
	Values   zncValues
	Children []zncSection
}

// ForEach calls f, in file order, for each direct child of s whose Type is
// typ. The pointer handed to f addresses the element stored in s.Children
// (not a per-iteration copy), so changes f makes to the child are kept.
func (s *zncSection) ForEach(typ string, f func(*zncSection)) {
	for i := range s.Children {
		if child := &s.Children[i]; child.Type == typ {
			f(child)
		}
	}
}

// zncValues holds the "Key = Value" pairs of a section. A key may occur more
// than once (e.g. "LoadModule"), so every value is kept, in file order.
type zncValues map[string][]string

// Get returns the first value recorded for k, or "" if k is absent.
func (zv zncValues) Get(k string) string {
	if len(zv[k]) == 0 {
		return ""
	}
	return zv[k][0]
}

// zncParser reads a ZNC configuration file. line is the current line number;
// it is advanced on every newline consumed and is meant for error messages.
type zncParser struct {
	br   *bufio.Reader
	line int
}

// readByte reads one byte, counting newlines.
func (zp *zncParser) readByte() (byte, error) {
	b, err := zp.br.ReadByte()
	if b == '\n' {
		zp.line++
	}
	return b, err
}

// readRune reads one UTF-8 rune, counting newlines.
func (zp *zncParser) readRune() (rune, int, error) {
	r, n, err := zp.br.ReadRune()
	if r == '\n' {
		zp.line++
	}
	return r, n, err
}

// sectionBody parses the contents of a section named typ/name. For a nested
// section (typ != "") it stops in front of the "</...>" footer, which the
// caller consumes and checks. For the root section (typ == "") it runs to
// end of input. Child sections are parsed recursively and their footers
// checked against their headers.
func (zp *zncParser) sectionBody(typ, name string) (*zncSection, error) {
	section := &zncSection{Type: typ, Name: name, Values: make(zncValues)}

	for {
		if err := zp.skipSpace(); err != nil {
			return nil, err
		}

		b, err := zp.br.Peek(2)
		if err == io.EOF {
			switch {
			case len(b) > 0:
				// A single non-space byte is left: no construct fits in it.
				return nil, fmt.Errorf("unexpected end of file")
			case typ != "":
				return nil, fmt.Errorf("unexpected end of file: missing </%v>", typ)
			}
			return section, nil
		} else if err != nil {
			return nil, err
		}

		switch {
		case b[0] == '<' && b[1] == '/':
			// Footer of the current section: hand it back to the caller.
			if typ == "" {
				return nil, fmt.Errorf("unexpected section footer outside of any section")
			}
			return section, nil
		case b[0] == '<':
			childType, childName, err := zp.sectionHeader()
			if err != nil {
				return nil, err
			}
			child, err := zp.sectionBody(childType, childName)
			if err != nil {
				return nil, err
			}
			footerType, err := zp.sectionFooter()
			if err != nil {
				return nil, err
			}
			if footerType != childType {
				return nil, fmt.Errorf("invalid section footer: expected type %q, got %q", childType, footerType)
			}
			section.Children = append(section.Children, *child)
		case b[0] == '/' && b[1] == '/':
			if err := zp.skipComment(); err != nil {
				return nil, err
			}
		default:
			k, v, err := zp.keyValuePair()
			if err != nil {
				return nil, err
			}
			section.Values[k] = append(section.Values[k], v)
		}
	}
}

// skipSpace consumes Unicode whitespace up to the next non-space rune or end
// of input. Reaching end of input is not an error.
func (zp *zncParser) skipSpace() error {
	for {
		r, _, err := zp.readRune()
		if err == io.EOF {
			return nil
		} else if err != nil {
			return err
		}

		if !unicode.IsSpace(r) {
			// Push the rune back for the caller. It isn't '\n' (a space), so
			// the line counter needs no adjustment, and UnreadRune cannot
			// fail right after a successful ReadRune.
			zp.br.UnreadRune()
			return nil
		}
	}
}

// skipComment consumes a "//" comment through the end of its line, or
// through end of input.
func (zp *zncParser) skipComment() error {
	if err := zp.expectRune('/'); err != nil {
		return err
	}
	if err := zp.expectRune('/'); err != nil {
		return err
	}

	for {
		b, err := zp.readByte()
		if err == io.EOF {
			return nil
		} else if err != nil {
			return err
		}

		if b == '\n' {
			return nil
		}
	}
}

// sectionHeader parses "<Type Name>" and returns the type and the name.
func (zp *zncParser) sectionHeader() (string, string, error) {
	if err := zp.expectRune('<'); err != nil {
		return "", "", err
	}
	typ, err := zp.readWord(' ')
	if err != nil {
		return "", "", err
	}
	name, err := zp.readWord('>')
	return typ, name, err
}

// sectionFooter parses "</Type>" and returns the type.
func (zp *zncParser) sectionFooter() (string, error) {
	if err := zp.expectRune('<'); err != nil {
		return "", err
	}
	if err := zp.expectRune('/'); err != nil {
		return "", err
	}
	return zp.readWord('>')
}

// keyValuePair parses one "Key = Value" line and returns the key and value
// with surrounding whitespace trimmed.
func (zp *zncParser) keyValuePair() (string, string, error) {
	k, err := zp.readWord('=')
	if err != nil {
		return "", "", err
	}
	v, err := zp.readWord('\n')
	return strings.TrimSpace(k), strings.TrimSpace(v), err
}

// expectRune consumes one rune and fails unless it equals expected.
func (zp *zncParser) expectRune(expected rune) error {
	r, _, err := zp.readRune()
	if err != nil {
		return err
	} else if r != expected {
		return fmt.Errorf("expected %q, got %q", expected, r)
	}
	return nil
}

// readWord reads bytes up to delim, consuming the delimiter but not
// returning it. Hitting a newline before delim is an error. When delim is
// '\n', end of input also ends the word, so that a last line without a
// trailing newline is accepted.
func (zp *zncParser) readWord(delim byte) (string, error) {
	var sb strings.Builder
	for {
		b, err := zp.readByte()
		if err == io.EOF && delim == '\n' {
			return sb.String(), nil
		} else if err != nil {
			return "", err
		}

		if b == delim {
			return sb.String(), nil
		}
		if b == '\n' {
			return "", fmt.Errorf("expected %q before newline", delim)
		}

		sb.WriteByte(b)
	}
}
462
// stringSetFlag is a flag.Value that collects every occurrence of a
// repeatable string flag into a set. The zero value (a nil map) is ready to
// use: Set allocates the map on first use.
type stringSetFlag map[string]bool

// Compile-time check that *stringSetFlag implements flag.Value.
var _ flag.Value = (*stringSetFlag)(nil)

// String formats the set like a map[string]bool. The flag package may call
// String on a zero-valued receiver, such as a nil pointer, which is treated
// as an empty set.
func (v *stringSetFlag) String() string {
	if v == nil {
		return fmt.Sprint(map[string]bool(nil))
	}
	return fmt.Sprint(map[string]bool(*v))
}

// Set adds s to the set. It never fails.
func (v *stringSetFlag) Set(s string) error {
	if *v == nil {
		*v = make(stringSetFlag)
	}
	(*v)[s] = true
	return nil
}
Note: See TracBrowser for help on using the repository browser.