13
13
14
14
class TestSwagBart (unittest .TestCase ):
15
15
16
- pretrained_model_name = 'Finnish-NLP/bart-small-finnish'
17
- # pretrained_model_name = 'sshleifer/bart-tiny-random'
16
+ # pretrained_model_name = 'Finnish-NLP/bart-small-finnish'
17
+ pretrained_model_name = 'sshleifer/bart-tiny-random'
18
18
19
19
def test_untrained (self ):
20
20
hidden_size = 240
@@ -45,12 +45,11 @@ def pretrained_bart_generative(self, no_cov_mat):
45
45
tokenizer = AutoTokenizer .from_pretrained (self .pretrained_model_name , clean_up_tokenization_spaces = False )
46
46
47
47
gen_config = GenerationConfig .from_model_config (model .config )
48
- logging .warning (gen_config )
49
48
gen_config .max_new_tokens = 10
50
- logging .warning (gen_config )
49
+ logging .debug (gen_config )
51
50
52
51
swag_model .swag .collect_model (model )
53
- swag_model .sample_parameters (cov = not no_cov_mat )
52
+ swag_model .sample_parameters (cov = not no_cov_mat , seed = 1234 )
54
53
# has to be updated manually when using collect_model directly
55
54
swag_model .config .cov_mat_rank = swag_model .swag .cov_mat_rank
56
55
@@ -61,29 +60,26 @@ def pretrained_bart_generative(self, no_cov_mat):
61
60
62
61
# Test generate
63
62
example = "I have no BART and I must generate"
64
- torch .manual_seed (123 )
65
63
batch = tokenizer (example , return_tensors = "pt" )
66
64
base_generated_ids = model .generate (batch ["input_ids" ], generation_config = gen_config )
67
- # max_length=20, num_beams=1, do_sample=False, early_stopping=False
68
65
base_out = tokenizer .batch_decode (base_generated_ids , skip_special_tokens = True )
69
- logging .warning (base_out )
70
66
71
67
generated_ids = swag_model .generate (batch ["input_ids" ], generation_config = gen_config )
72
68
out = tokenizer .batch_decode (generated_ids , skip_special_tokens = True )
73
- logging .warning (out )
74
69
self .assertEqual (base_out , out )
75
70
76
71
# Test saving & loading
77
72
with tempfile .TemporaryDirectory () as tempdir :
78
73
swag_model .save_pretrained (tempdir )
79
- logging .warning (os .listdir (tempdir ))
80
- with open (os .path .join (tempdir , 'config.json' ), 'r' ) as fobj :
81
- logging .warning (fobj .read ())
74
+ logging .debug (os .listdir (tempdir ))
75
+ with open (os .path .join (tempdir , 'config.json' ), 'r' , encoding = 'utf8' ) as fobj :
76
+ logging .debug (fobj .read ())
82
77
stored_model = SwagBartForConditionalGeneration .from_pretrained (tempdir ).to (device )
83
78
79
+ stored_model .sample_parameters (cov = not no_cov_mat , seed = 1234 )
84
80
stored_fwd_out = stored_model .forward (
85
81
input_ids = torch .tensor ([[3 , 14 ]]), decoder_input_ids = torch .tensor ([[1 , 2 , 4 ]]))
86
- self .assertTrue (torch .allclose (swag_fwd_out .logits , stored_fwd_out .logits ))
82
+ self .assertTrue (torch .allclose (swag_fwd_out .logits , stored_fwd_out .logits , atol = 1e-06 ))
87
83
88
84
generated_ids = stored_model .generate (batch ["input_ids" ], generation_config = gen_config )
89
85
out = tokenizer .batch_decode (generated_ids , skip_special_tokens = True )
0 commit comments