source: code/trunk/contrib/znc-import.go@ 646

Last change on this file since 646 was 620, checked in by hubert, 4 years ago

PostgreSQL support

File size: 10.1 KB
RevLine 
[357]1package main
2
import (
	"bufio"
	"flag"
	"fmt"
	"io"
	"log"
	"net"
	"net/url"
	"os"
	"strings"
	"unicode"

	"git.sr.ht/~emersion/soju"
	"git.sr.ht/~emersion/soju/config"
)
17
// usage is the help text shown by flag.Usage, which init replaces so that
// this text is printed instead of the flag package's generated defaults.
// main also prints it (and exits with status 1) when no ZNC configuration
// path is given on the command line.
const usage = `usage: znc-import [options...] <znc config path>

Imports configuration from a ZNC file. Users and networks are merged if they
already exist in the soju database. ZNC settings overwrite existing soju
settings.

Options:

  -help             Show this help message
  -config <path>    Path to soju config file
  -user <username>  Limit import to username (may be specified multiple times)
  -network <name>   Limit import to network (may be specified multiple times)
`
31
32func init() {
33 flag.Usage = func() {
34 fmt.Fprintf(flag.CommandLine.Output(), usage)
35 }
36}
37
38func main() {
39 var configPath string
40 users := make(map[string]bool)
41 networks := make(map[string]bool)
42 flag.StringVar(&configPath, "config", "", "path to configuration file")
43 flag.Var((*stringSetFlag)(&users), "user", "")
44 flag.Var((*stringSetFlag)(&networks), "network", "")
45 flag.Parse()
46
47 zncPath := flag.Arg(0)
48 if zncPath == "" {
49 flag.Usage()
50 os.Exit(1)
51 }
52
53 var cfg *config.Server
54 if configPath != "" {
55 var err error
56 cfg, err = config.Load(configPath)
57 if err != nil {
58 log.Fatalf("failed to load config file: %v", err)
59 }
60 } else {
61 cfg = config.Defaults()
62 }
63
[620]64 db, err := soju.OpenDB(cfg.SQLDriver, cfg.SQLSource)
[357]65 if err != nil {
66 log.Fatalf("failed to open database: %v", err)
67 }
68 defer db.Close()
69
70 f, err := os.Open(zncPath)
71 if err != nil {
72 log.Fatalf("failed to open ZNC configuration file: %v", err)
73 }
74 defer f.Close()
75
76 zp := zncParser{bufio.NewReader(f), 1}
77 root, err := zp.sectionBody("", "")
78 if err != nil {
79 log.Fatalf("failed to parse %q: line %v: %v", zncPath, zp.line, err)
80 }
81
82 l, err := db.ListUsers()
83 if err != nil {
84 log.Fatalf("failed to list users in DB: %v", err)
85 }
86 existingUsers := make(map[string]*soju.User, len(l))
87 for i, u := range l {
88 existingUsers[u.Username] = &l[i]
89 }
90
91 usersCreated := 0
92 usersImported := 0
93 networksImported := 0
94 channelsImported := 0
95 root.ForEach("User", func(section *zncSection) {
96 username := section.Name
97 if len(users) > 0 && !users[username] {
98 return
99 }
100 usersImported++
101
102 u, ok := existingUsers[username]
103 if ok {
104 log.Printf("user %q: updating existing user", username)
105 } else {
106 // "!!" is an invalid crypt format, thus disables password auth
107 u = &soju.User{Username: username, Password: "!!"}
108 usersCreated++
109 log.Printf("user %q: creating new user", username)
110 }
111
112 u.Admin = section.Values.Get("Admin") == "true"
113
114 if err := db.StoreUser(u); err != nil {
115 log.Fatalf("failed to store user %q: %v", username, err)
116 }
[421]117 userID := u.ID
[357]118
[421]119 l, err := db.ListNetworks(userID)
[357]120 if err != nil {
121 log.Fatalf("failed to list networks for user %q: %v", username, err)
122 }
123 existingNetworks := make(map[string]*soju.Network, len(l))
124 for i, n := range l {
125 existingNetworks[n.GetName()] = &l[i]
126 }
127
128 nick := section.Values.Get("Nick")
129 realname := section.Values.Get("RealName")
130 ident := section.Values.Get("Ident")
131
132 section.ForEach("Network", func(section *zncSection) {
133 netName := section.Name
134 if len(networks) > 0 && !networks[netName] {
135 return
136 }
137 networksImported++
138
139 logPrefix := fmt.Sprintf("user %q: network %q: ", username, netName)
140 logger := log.New(os.Stderr, logPrefix, log.LstdFlags|log.Lmsgprefix)
141
142 netNick := section.Values.Get("Nick")
143 if netNick == "" {
144 netNick = nick
145 }
146 netRealname := section.Values.Get("RealName")
147 if netRealname == "" {
148 netRealname = realname
149 }
150 netIdent := section.Values.Get("Ident")
151 if netIdent == "" {
152 netIdent = ident
153 }
154
155 for _, name := range section.Values["LoadModule"] {
156 switch name {
157 case "sasl":
158 logger.Printf("warning: SASL credentials not imported")
159 case "nickserv":
160 logger.Printf("warning: NickServ credentials not imported")
161 case "perform":
162 logger.Printf("warning: \"perform\" plugin commands not imported")
163 }
164 }
165
166 u, pass, err := importNetworkServer(section.Values.Get("Server"))
167 if err != nil {
168 logger.Fatalf("failed to import server %q: %v", section.Values.Get("Server"), err)
169 }
170
171 n, ok := existingNetworks[netName]
172 if ok {
173 logger.Printf("updating existing network")
174 } else {
175 n = &soju.Network{Name: netName}
176 logger.Printf("creating new network")
177 }
178
179 n.Addr = u.String()
180 n.Nick = netNick
181 n.Username = netIdent
182 n.Realname = netRealname
183 n.Pass = pass
[542]184 n.Enabled = section.Values.Get("IRCConnectEnabled") != "false"
[357]185
[421]186 if err := db.StoreNetwork(userID, n); err != nil {
[357]187 logger.Fatalf("failed to store network: %v", err)
188 }
189
190 l, err := db.ListChannels(n.ID)
191 if err != nil {
192 logger.Fatalf("failed to list channels: %v", err)
193 }
194 existingChannels := make(map[string]*soju.Channel, len(l))
195 for i, ch := range l {
196 existingChannels[ch.Name] = &l[i]
197 }
198
199 section.ForEach("Chan", func(section *zncSection) {
200 chName := section.Name
201
202 if section.Values.Get("Disabled") == "true" {
203 logger.Printf("skipping import of disabled channel %q", chName)
204 return
205 }
206
207 channelsImported++
208
209 ch, ok := existingChannels[chName]
210 if ok {
211 logger.Printf("channel %q: updating existing channel", chName)
212 } else {
213 ch = &soju.Channel{Name: chName}
214 logger.Printf("channel %q: creating new channel", chName)
215 }
216
217 ch.Key = section.Values.Get("Key")
218 ch.Detached = section.Values.Get("Detached") == "true"
219
220 if err := db.StoreChannel(n.ID, ch); err != nil {
221 logger.Printf("channel %q: failed to store channel: %v", chName, err)
222 }
223 })
224 })
225 })
226
227 if err := db.Close(); err != nil {
228 log.Printf("failed to close database: %v", err)
229 }
230
231 if usersCreated > 0 {
232 log.Printf("warning: user passwords haven't been imported, please set them with `sojuctl change-password <username>`")
233 }
234
235 log.Printf("imported %v users, %v networks and %v channels", usersImported, networksImported, channelsImported)
236}
237
// importNetworkServer converts a ZNC "Server" value of the form
// "host [+]port [password]" into a soju server URL and password.
//
// A "+" in front of the port selects TLS ("ircs"); otherwise the plaintext
// "irc+insecure" scheme is used. The optional third field is returned as the
// server password. An error is returned when host or port is missing.
func importNetworkServer(s string) (u *url.URL, pass string, err error) {
	parts := strings.Fields(s)
	if len(parts) < 2 {
		return nil, "", fmt.Errorf("expected space-separated host and port")
	}

	scheme := "irc+insecure"
	host, port := parts[0], parts[1]
	if strings.HasPrefix(port, "+") {
		port = strings.TrimPrefix(port, "+")
		scheme = "ircs"
	}

	if len(parts) > 2 {
		pass = parts[2]
	}

	// ZNC separates host and port with a space, so IPv6 addresses appear
	// without brackets. JoinHostPort adds the brackets a URL authority needs;
	// strip any existing ones first so they are not doubled.
	if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
		host = host[1 : len(host)-1]
	}

	u = &url.URL{
		Scheme: scheme,
		Host:   net.JoinHostPort(host, port),
	}
	return u, pass, nil
}
262
263type zncSection struct {
264 Type string
265 Name string
266 Values zncValues
267 Children []zncSection
268}
269
270func (s *zncSection) ForEach(typ string, f func(*zncSection)) {
271 for _, section := range s.Children {
272 if section.Type == typ {
273 f(&section)
274 }
275 }
276}
277
// zncValues maps a configuration key to every value given for it, in file
// order; keys such as LoadModule may appear several times in one section.
type zncValues map[string][]string

// Get returns the first value recorded for key k, or "" when k has no value.
func (zv zncValues) Get(k string) string {
	if values := zv[k]; len(values) > 0 {
		return values[0]
	}
	return ""
}
286
287type zncParser struct {
288 br *bufio.Reader
289 line int
290}
291
292func (zp *zncParser) readByte() (byte, error) {
293 b, err := zp.br.ReadByte()
294 if b == '\n' {
295 zp.line++
296 }
297 return b, err
298}
299
300func (zp *zncParser) readRune() (rune, int, error) {
301 r, n, err := zp.br.ReadRune()
302 if r == '\n' {
303 zp.line++
304 }
305 return r, n, err
306}
307
308func (zp *zncParser) sectionBody(typ, name string) (*zncSection, error) {
309 section := &zncSection{Type: typ, Name: name, Values: make(zncValues)}
310
311Loop:
312 for {
313 if err := zp.skipSpace(); err != nil {
314 return nil, err
315 }
316
317 b, err := zp.br.Peek(2)
318 if err == io.EOF {
319 break
320 } else if err != nil {
321 return nil, err
322 }
323
324 switch b[0] {
325 case '<':
326 if b[1] == '/' {
327 break Loop
328 } else {
329 childType, childName, err := zp.sectionHeader()
330 if err != nil {
331 return nil, err
332 }
333 child, err := zp.sectionBody(childType, childName)
334 if err != nil {
335 return nil, err
336 }
337 if footerType, err := zp.sectionFooter(); err != nil {
338 return nil, err
339 } else if footerType != childType {
340 return nil, fmt.Errorf("invalid section footer: expected type %q, got %q", childType, footerType)
341 }
342 section.Children = append(section.Children, *child)
343 }
344 case '/':
345 if b[1] == '/' {
346 if err := zp.skipComment(); err != nil {
347 return nil, err
348 }
349 break
350 }
351 fallthrough
352 default:
353 k, v, err := zp.keyValuePair()
354 if err != nil {
355 return nil, err
356 }
357 section.Values[k] = append(section.Values[k], v)
358 }
359 }
360
361 return section, nil
362}
363
364func (zp *zncParser) skipSpace() error {
365 for {
366 r, _, err := zp.readRune()
367 if err == io.EOF {
368 return nil
369 } else if err != nil {
370 return err
371 }
372
373 if !unicode.IsSpace(r) {
374 zp.br.UnreadRune()
375 return nil
376 }
377 }
378}
379
380func (zp *zncParser) skipComment() error {
381 if err := zp.expectRune('/'); err != nil {
382 return err
383 }
384 if err := zp.expectRune('/'); err != nil {
385 return err
386 }
387
388 for {
389 b, err := zp.readByte()
390 if err == io.EOF {
391 return nil
392 } else if err != nil {
393 return err
394 }
395
396 if b == '\n' {
397 return nil
398 }
399 }
400}
401
402func (zp *zncParser) sectionHeader() (string, string, error) {
403 if err := zp.expectRune('<'); err != nil {
404 return "", "", err
405 }
406 typ, err := zp.readWord(' ')
407 if err != nil {
408 return "", "", err
409 }
410 name, err := zp.readWord('>')
411 return typ, name, err
412}
413
414func (zp *zncParser) sectionFooter() (string, error) {
415 if err := zp.expectRune('<'); err != nil {
416 return "", err
417 }
418 if err := zp.expectRune('/'); err != nil {
419 return "", err
420 }
421 return zp.readWord('>')
422}
423
424func (zp *zncParser) keyValuePair() (string, string, error) {
425 k, err := zp.readWord('=')
426 if err != nil {
427 return "", "", err
428 }
429 v, err := zp.readWord('\n')
430 return strings.TrimSpace(k), strings.TrimSpace(v), err
431}
432
433func (zp *zncParser) expectRune(expected rune) error {
434 r, _, err := zp.readRune()
435 if err != nil {
436 return err
437 } else if r != expected {
438 return fmt.Errorf("expected %q, got %q", expected, r)
439 }
440 return nil
441}
442
443func (zp *zncParser) readWord(delim byte) (string, error) {
444 var sb strings.Builder
445 for {
446 b, err := zp.readByte()
447 if err != nil {
448 return "", err
449 }
450
451 if b == delim {
452 return sb.String(), nil
453 }
454 if b == '\n' {
455 return "", fmt.Errorf("expected %q before newline", delim)
456 }
457
458 sb.WriteByte(b)
459 }
460}
461
// stringSetFlag is a flag.Value that collects every occurrence of a
// repeatable command-line flag into a set.
type stringSetFlag map[string]bool

// String formats the set for the flag package. The flag.Value contract allows
// String to be called on a zero-valued receiver such as a nil pointer, so
// that case returns "" instead of panicking.
func (v *stringSetFlag) String() string {
	if v == nil {
		return ""
	}
	return fmt.Sprint(map[string]bool(*v))
}

// Set adds s to the set, allocating the underlying map on first use so that
// a nil set is usable.
func (v *stringSetFlag) Set(s string) error {
	if *v == nil {
		*v = make(stringSetFlag)
	}
	(*v)[s] = true
	return nil
}
Note: See TracBrowser for help on using the repository browser.