-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathdraw.py
119 lines (98 loc) · 3.67 KB
/
draw.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#!/usr/bin/env python
# coding: utf-8
import plotly.offline as py
import plotly.graph_objs as go
import numpy as np
import collections as cl
import itertools as it
from tqdm import tqdm
import os
import argparse as ap
def _draw_scatter(all_vocabs, all_freqs, output_prefix):
    """Write ``<output_prefix>_scatter.html`` with one marker per token.

    Args:
        all_vocabs: token strings, used as each point's text label.
        all_freqs: ``(src_count, trg_count)`` pairs aligned with
            ``all_vocabs``.
        output_prefix: path prefix of the generated HTML file.

    The x coordinate is the source-side count and y the target-side count,
    both on log-scaled axes.  Marker color is ``log(min / max)`` of the two
    counts: 0 for a token equally frequent on both sides, more negative the
    more lopsided it is.  Tokens with a zero count on either side are
    colored 0 as well.  The file is written without opening a browser.
    """
    def balance_color(src, trg):
        # No ratio exists unless the token occurs on both sides.
        if not (src and trg):
            return 0
        log_ratio = np.log(min(src, trg) / max(src, trg))
        # A perfectly balanced token (log 1 == 0.0) is reported as int 0.
        return log_ratio if log_ratio else 0

    marker_colors = [balance_color(src, trg) for src, trg in all_freqs]
    src_counts = [pair[0] for pair in all_freqs]
    trg_counts = [pair[1] for pair in all_freqs]

    scatter = go.Scattergl(
        x=src_counts,
        y=trg_counts,
        mode='markers',
        text=all_vocabs,
        marker=dict(color=marker_colors, showscale=True,
                    colorscale='Viridis'))

    log_axis = dict(type='log', autorange=True)
    layout = go.Layout(
        title='Scatter plot of shared tokens',
        hovermode='closest',
        xaxis=dict(title='src freq', **log_axis),
        yaxis=dict(title='trg freq', **log_axis))

    figure = go.Figure(data=[scatter], layout=layout)
    out_html = '{}_scatter.html'.format(output_prefix)
    py.plot(figure, filename=out_html, auto_open=False)
def _draw_rate(all_vocabs, all_freqs, output_prefix):
    """Write ``<output_prefix>_rate.html`` showing token mass versus bias.

    For every token with a non-zero count on both sides, the bias is
    ``max(s, t) / min(s, t)`` (1.0 means equally frequent on both sides).
    Shared tokens are binned by bias, each weighted by its total count
    ``s + t``, and two curves are drawn:

    * the weighted histogram (token occurrences per bias bin), plotted at
      the bin centers on a log y-axis;
    * the cumulative fraction of *all* token occurrences (shared and
      unshared) whose bias is at or below each bin's upper edge, on a
      secondary y-axis.

    Args:
        all_vocabs: token strings aligned with ``all_freqs``; unused here,
            kept so the signature parallels ``_draw_scatter``.
        all_freqs: sequence of ``(src_count, trg_count)`` pairs.
        output_prefix: path prefix of the generated HTML file.

    If no token occurs on both sides there is nothing to bin; a
    ``UserWarning`` is issued and no file is written (previously
    ``np.histogram`` raised on ``bins=0``).
    """
    import warnings

    # reshape keeps an empty input as a (0, 2) array instead of (0,).
    counts = np.array(all_freqs, dtype=float).reshape(-1, 2)
    src, trg = counts[:, 0], counts[:, 1]
    shared = (src > 0) & (trg > 0)
    if not shared.any():
        warnings.warn(
            'no token occurs in both corpora; not writing {}_rate.html'
            .format(output_prefix))
        return

    lo = np.minimum(src[shared], trg[shared])
    hi = np.maximum(src[shared], trg[shared])
    biases = hi / lo  # always >= 1 for shared tokens
    hist, bin_edges = np.histogram(
        biases, weights=lo + hi, bins=int(biases.max()))

    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    # cumsum through bin i covers every token whose bias is <= that bin's
    # upper edge, so the cumulative curve is anchored at the upper edges.
    bin_upper_edges = bin_edges[1:]
    # Normalized by all occurrences (shared + unshared), so the final value
    # is the overall fraction of occurrences that belong to shared tokens.
    share_token_rates = np.cumsum(hist) / counts.sum()

    t1 = go.Scatter(
        x=bin_centers,
        y=hist,
        name='num of tokens',
        mode='lines',
        fill='tozeroy')
    t2 = go.Scatter(
        x=bin_upper_edges,
        y=share_token_rates,
        name='share token rates',
        mode='lines',
        yaxis='y2')
    layout = go.Layout(
        title='Shared tokens rates',
        xaxis=dict(title='bias', autorange=True),
        yaxis=dict(title='num of tokens', type='log', autorange=True),
        yaxis2=dict(
            title='accumulative share token rates',
            autorange=True,
            side='right',
            overlaying='y'))
    fig = go.Figure(data=[t1, t2], layout=layout)
    py.plot(
        fig, filename='{}_rate.html'.format(output_prefix), auto_open=False)
def _count_tokens(fname):
    """Return a Counter of whitespace-separated tokens in the file *fname*.

    The file is read line by line as UTF-8 behind a tqdm progress bar and
    closed when done.
    """
    desc = 'gen vocab from {}'.format(os.path.basename(fname))
    with open(fname, encoding='utf-8') as corpus:
        # str.split() with no argument already strips surrounding whitespace.
        return cl.Counter(
            tok for line in tqdm(corpus, desc=desc) for tok in line.split())


def main(args):
    """Count tokens in both corpora and draw the requested plot(s).

    Args:
        args: namespace with ``src_fname``, ``trg_fname``, ``type``
            (``'scatter'``, ``'rate'``, ``'both'`` or None) and
            ``output_prefix``.

    A missing ``type`` (None) is treated as ``'both'``.  If either corpus
    contains no tokens, a message is printed to stderr and nothing is drawn.
    """
    import sys

    src_freqs = _count_tokens(args.src_fname)
    trg_freqs = _count_tokens(args.trg_fname)
    if not src_freqs or not trg_freqs:
        print('nothing to draw: at least one input file contains no tokens',
              file=sys.stderr)
        return

    # Sorted so the point order (and thus the HTML output) is reproducible;
    # iterating a raw set union depends on string-hash randomization.
    all_vocabs = sorted(src_freqs.keys() | trg_freqs.keys())
    # Counter lookups return 0 for absent tokens without inserting them.
    all_freqs = [(src_freqs[v], trg_freqs[v]) for v in all_vocabs]

    # argparse yields None when --type is omitted; draw both rather than
    # silently producing nothing.
    draw_type = args.type or 'both'
    if draw_type in ('rate', 'both'):
        _draw_rate(all_vocabs, all_freqs, args.output_prefix)
    if draw_type in ('scatter', 'both'):
        _draw_scatter(all_vocabs, all_freqs, args.output_prefix)
if __name__ == '__main__':
    # Command-line entry point: parse the two corpus paths and plot options,
    # then hand off to main().
    draw_parser = ap.ArgumentParser(
        description='Plot how the token vocabularies of two '
                    'whitespace-tokenized corpora overlap.')
    draw_parser.add_argument(
        'src_fname', type=str, help='the source file name.')
    draw_parser.add_argument(
        'trg_fname', type=str, help='the target file name.')
    draw_parser.add_argument(
        '--type',
        type=str,
        choices=['scatter', 'rate', 'both'],
        # Without a default, omitting --type used to draw nothing at all.
        default='both',
        help='which plot(s) to draw (default: both).')
    draw_parser.add_argument(
        '--output_prefix', default='pref', help='output prefix.')
    args = draw_parser.parse_args()
    main(args)