Commit
Fix inference & typo (#602)
CrazyBoyM authored Jul 12, 2024
1 parent 78f5d5f commit 598b8c9
Showing 6 changed files with 32 additions and 11 deletions.
20 changes: 17 additions & 3 deletions ppdiffusers/ppdiffusers/models/lora.py
@@ -304,7 +304,7 @@ def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs
self.out_channels = self._out_channels
self.kernel_size = self._kernel_size
self.lora_layer = lora_layer
- self.data_format = kwargs["data_format"]
+ self.data_format = kwargs.get("data_format", "NCHW")

def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
self.lora_layer = lora_layer
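A quick note on the change above: indexing kwargs["data_format"] raises a KeyError whenever the layer is constructed without a data_format argument, while kwargs.get("data_format", "NCHW") falls back to Paddle's default channel layout. A minimal, self-contained sketch of the difference (plain Python, function names are hypothetical):

# Minimal sketch: dict indexing vs. dict.get with a default.
def init_with_index(**kwargs):
    return kwargs["data_format"]              # raises KeyError if the key is absent

def init_with_default(**kwargs):
    return kwargs.get("data_format", "NCHW")  # falls back to Paddle's default layout

print(init_with_default())                    # -> NCHW
print(init_with_default(data_format="NHWC"))  # -> NHWC
# init_with_index() would raise KeyError: 'data_format'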
@@ -366,11 +366,25 @@ def forward(self, hidden_states: paddle.Tensor, scale: float = 1.0) -> paddle.Te
# make sure to the functional Conv2D function as otherwise torch.compile's graph will break
# see: https://github.com/huggingface/diffusers/pull/4315
return nn.functional.conv2d(
- hidden_states, self.weight, self.bias, self._stride, self._padding, self._dilation, self._groups, data_format=self.data_format,
+ hidden_states,
+ self.weight,
+ self.bias,
+ self._stride,
+ self._padding,
+ self._dilation,
+ self._groups,
+ data_format=self.data_format,
)
else:
original_outputs = nn.functional.conv2d(
- hidden_states, self.weight, self.bias, self._stride, self._padding, self._dilation, self._groups, data_format=self.data_format,
+ hidden_states,
+ self.weight,
+ self.bias,
+ self._stride,
+ self._padding,
+ self._dilation,
+ self._groups,
+ data_format=self.data_format,
)
return original_outputs + (scale * self.lora_layer(hidden_states))
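For readers unfamiliar with Paddle's functional conv API: the rewritten calls above are behaviorally identical to the one-line versions, just split one argument per line, while still forwarding data_format. A small, self-contained sketch of the same call pattern, assuming Paddle is installed (shapes and values are illustrative only):

import paddle
import paddle.nn.functional as F

# Illustrative 1x1 convolution applied functionally, mirroring the call
# pattern in LoRACompatibleConv.forward; all shapes here are made up.
x = paddle.randn([1, 4, 8, 8])   # NCHW input
w = paddle.randn([4, 4, 1, 1])   # [out_channels, in_channels, kH, kW]
b = paddle.zeros([4])
y = F.conv2d(
    x,
    w,
    b,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
    data_format="NCHW",
)
print(y.shape)                   # [1, 4, 8, 8]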

@@ -94,5 +94,6 @@ def __init__(self, config, *args, **kwargs):

def forward(self, input_ids, attention_mask):
embs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)[0]
+ attention_mask = attention_mask.cast(embs.dtype)
embs2 = (embs * attention_mask.unsqueeze(2)).sum(axis=1) / attention_mask.sum(axis=1)[:, None]
return self.LinearTransformation(embs2), embs
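The single added line above casts the integer attention mask to the embeddings' floating-point dtype, presumably so the masked mean pooling multiplies and divides tensors of matching dtypes (and keeps working when the embeddings are half precision). A standalone sketch of the same pooling, assuming Paddle is installed (tensor values are illustrative):

import paddle

# Masked mean pooling over the sequence dimension, as in forward() above.
embs = paddle.randn([2, 5, 8])                      # [batch, seq_len, hidden]
attention_mask = paddle.to_tensor([[1, 1, 1, 0, 0],
                                   [1, 1, 0, 0, 0]], dtype="int64")
attention_mask = attention_mask.cast(embs.dtype)    # int64 -> float32, the added line
pooled = (embs * attention_mask.unsqueeze(2)).sum(axis=1) / attention_mask.sum(axis=1)[:, None]
print(pooled.shape)                                 # [2, 8]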
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+ import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import paddle
@@ -461,8 +462,11 @@ def __call__(

# predict the noise residual
noise_pred = self.unet(**unet_inputs)[0]
- if str(os.environ.get("FLAGS_model_return_data")).lower() in ("true", "1")::
- print(f"StableDiffusion infer: step {i+1} , origin output {noise_pred.abs().numpy().mean()} ", flush=True)
+ if str(os.environ.get("FLAGS_model_return_data")).lower() in ("true", "1"):
+ print(
+ f"StableDiffusion infer: step {i+1} , origin output {noise_pred.abs().numpy().mean()} ",
+ flush=True,
+ )

# perform guidance
if do_classifier_free_guidance:
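The corrected guard above (the removed line had a stray second colon) gates a per-step debug print behind an environment variable. os.environ.get returns None when the flag is unset, so wrapping it in str(...) before .lower() keeps the check safe; only "true"/"1" in any case enable it. A tiny self-contained sketch of the same check:

import os

def debug_print_enabled() -> bool:
    # Unset -> "none" -> disabled; "True", "true", "1" -> enabled.
    return str(os.environ.get("FLAGS_model_return_data")).lower() in ("true", "1")

print(debug_print_enabled())                  # False when the flag is unset
os.environ["FLAGS_model_return_data"] = "True"
print(debug_print_enabled())                  # True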
@@ -237,7 +237,9 @@ def warp_single_latent(latent, reference_flow)
coords_t0 = F.interpolate(coords_t0, size=(h, w), mode="bilinear")
coords_t0 = paddle.transpose(coords_t0, (0, 2, 3, 1))

- warped = grid_sample(latent, coords_t0, mode="nearest", padding_mode="reflection")
+ warped = grid_sample(
+ latent.cast("float32").cpu(), coords_t0.cast("float32").cpu(), mode="nearest", padding_mode="reflection"
+ )
return warped
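The rewritten grid_sample call above casts both the latent and the sampling grid to float32 and moves them to CPU first; presumably the nearest-mode, reflection-padded sampling does not handle the pipeline's half-precision or GPU-placed tensors in some Paddle builds, though the commit message does not say. A minimal sketch of the call pattern, assuming Paddle is installed and that grid_sample here behaves like paddle.nn.functional.grid_sample (shapes are illustrative):

import paddle
import paddle.nn.functional as F

# Warp a latent by sampling it at normalized grid coordinates in [-1, 1],
# casting to float32 first as the diff does (the .cpu() move is omitted here).
latent = paddle.randn([1, 4, 8, 8])            # stand-in for a fp16 latent
coords = paddle.rand([1, 8, 8, 2]) * 2 - 1     # [N, H, W, 2] sampling grid
warped = F.grid_sample(
    latent.cast("float32"),
    coords.cast("float32"),
    mode="nearest",
    padding_mode="reflection",
)
print(warped.shape)                            # [1, 4, 8, 8]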


@@ -335,7 +335,7 @@ def generate_beam(
logits[is_stopped, 0] = 0
scores_sum = scores[:, None] + logits
seq_lengths[~is_stopped] += 1
- scores_sum_average = scores_sum / seq_lengths[:, None]
+ scores_sum_average = scores_sum / seq_lengths[:, None].cast(scores_sum.dtype)
scores_sum_average, next_tokens = scores_sum_average.reshape([-1]).topk(beam_size, -1)
next_tokens_source = next_tokens // scores_sum.shape[1]
seq_lengths = seq_lengths[next_tokens_source]
@@ -344,7 +344,7 @@
tokens = tokens[next_tokens_source]
tokens = paddle.concat(x=(tokens, next_tokens), axis=1)
generated = generated[next_tokens_source]
- scores = scores_sum_average * seq_lengths
+ scores = scores_sum_average * seq_lengths.cast(scores_sum_average.dtype)
is_stopped = paddle.cast(is_stopped, "int32") # TODO: nf
is_stopped = is_stopped[next_tokens_source]
is_stopped = paddle.cast(is_stopped, "bool")
@@ -357,7 +357,7 @@
if is_stopped.astype("bool").all():
break

- scores = scores / seq_lengths
+ scores = scores / seq_lengths.cast(scores.dtype)
order = scores.argsort(descending=True)
# tokens tensors are already padded to max_seq_length
output_texts = [tokens[i] for i in order]
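The three generate_beam edits above are the same fix applied in three places: seq_lengths is an integer tensor while the beam scores are floating point, and Paddle's elementwise divide and multiply want matching dtypes (older Paddle versions raise on mixed int/float operands), so the lengths are cast to the scores' dtype at each use when length-normalizing the running scores and when restoring them. A standalone sketch of the pattern, assuming Paddle is installed (values are illustrative):

import paddle

# Cast integer lengths to the float scores' dtype before dividing/multiplying,
# as the three changed lines do.
scores = paddle.to_tensor([1.5, 3.0, 4.5])        # float32 running scores
seq_lengths = paddle.to_tensor([1, 2, 3])         # int64 beam lengths
avg = scores / seq_lengths.cast(scores.dtype)     # length-normalized scores
restored = avg * seq_lengths.cast(avg.dtype)      # back to summed scores
print(avg.numpy(), restored.numpy())              # [1.5 1.5 1.5] [1.5 3. 4.5]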
4 changes: 2 additions & 2 deletions ppdiffusers/ppdiffusers/utils/download_utils.py
@@ -244,7 +244,7 @@ def _resumable_file_manager():
)

logger.info("storing %s in cache at %s", url_to_download, file_path)
- _chmod_and_move(temp_file.name, file_path)
+ _chmod_and_move(temp_file.name, Path(file_path))
try:
os.remove(lock_path)
except OSError:
@@ -380,7 +380,7 @@ def _resumable_file_manager():
)

logger.info("storing %s in cache at %s", url_to_download, blob_path)
- _chmod_and_move(temp_file.name, blob_path)
+ _chmod_and_move(temp_file.name, Path(blob_path))
try:
os.remove(lock_path)
except OSError:
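Both download_utils.py edits wrap the destination in pathlib.Path before handing it to _chmod_and_move. The helper itself is not shown in this diff, so the sketch below only assumes it relies on Path-only attributes (such as .parent) and therefore fails when given a plain string; the helper name and body here are hypothetical:

import os
import shutil
from pathlib import Path

def _chmod_and_move_sketch(src: str, dst: Path) -> None:
    # A str dst would fail here: .parent only exists on Path objects.
    dst.parent.mkdir(parents=True, exist_ok=True)
    os.chmod(src, 0o644)
    shutil.move(src, str(dst))

# Example (paths hypothetical):
# _chmod_and_move_sketch("/tmp/download.tmp", Path("/tmp/cache/model.bin"))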
