diff --git a/.github/workflows/cpp-graph-test.yml b/.github/workflows/cpp-graph-test.yml index d081a8e28..2d7332198 100644 --- a/.github/workflows/cpp-graph-test.yml +++ b/.github/workflows/cpp-graph-test.yml @@ -53,13 +53,13 @@ jobs: - name: Env build run: | - bash ${{ github.workspace }}/.github/workflows/scripts/prepare_env_with_conda.sh "cpp-graph-test" "3.8" + bash ${{ github.workspace }}/.github/workflows/scripts/prepare_env_with_conda.sh "cpp-graph-test-neural-speed" "3.8" - name: Binary build if: 0 == 1 run: | cd ${{ github.workspace }} - conda activate cpp-graph-test || source activate cpp-graph-test + conda activate cpp-graph-test-neural-speed || source activate cpp-graph-test-neural-speed pip install build --upgrade pip install -r requirements.txt python setup.py sdist bdist_wheel @@ -69,7 +69,7 @@ jobs: - name: BF16 Benchmark run: | cd ${{ github.workspace }}/.github/workflows/scripts/models - bash cpp_graph_inference.sh cpp-graph-test ${{ matrix.modelName }} ${{ env.INPUT_COMPILER_VERSION }} + bash cpp_graph_inference.sh cpp-graph-test-neural-speed ${{ matrix.modelName }} ${{ env.INPUT_COMPILER_VERSION }} - name: Rename summary run: | diff --git a/.github/workflows/format_scan.yml b/.github/workflows/format_scan.yml index 11de4f345..4297bedfd 100644 --- a/.github/workflows/format_scan.yml +++ b/.github/workflows/format_scan.yml @@ -5,9 +5,12 @@ on: branches: [main] paths: - neural_speed/** + - bestla/** + - scripts/** - setup.py - .github/workflows/format_scan.yml - .github/workflows/scripts/formatScan/** + - "!bestla/*.md" workflow_dispatch: # If there is a new commit, the previous jobs will be canceled diff --git a/.github/workflows/scripts/models/cpp_graph_inference.sh b/.github/workflows/scripts/models/cpp_graph_inference.sh index f7632830d..0ef03699b 100644 --- a/.github/workflows/scripts/models/cpp_graph_inference.sh +++ b/.github/workflows/scripts/models/cpp_graph_inference.sh @@ -62,7 +62,7 @@ function main() { if [[ "${compiler_version}" != "12.1.0" ]]; then conda install --update-deps -c conda-forge gxx==${compiler_version} gcc==${compiler_version} gxx_linux-64==${compiler_version} libstdcxx-ng sysroot_linux-64 -y fi - + export LD_LIBRARY_PATH=${HOME}/miniconda3/envs/${conda_env}/lib/:$LD_LIBRARY_PATH # compile binary cd ${working_dir} mkdir build diff --git a/neural_speed/models/requirements/baichuan.sh b/neural_speed/models/requirements/baichuan.sh new file mode 100644 index 000000000..75a88a8f9 --- /dev/null +++ b/neural_speed/models/requirements/baichuan.sh @@ -0,0 +1,19 @@ +#!/bin/bash +#=============================================================================== +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#=============================================================================== + +# To avoid the error: 'ChatGLMTokenizer' object has no attribute 'sp_tokenizer' +pip install -r "$(dirname "${BASH_SOURCE[0]}")/common.txt" transformers==4.33.1 diff --git a/neural_speed/models/requirements/chatglm-6b.sh b/neural_speed/models/requirements/chatglm-6b.sh new file mode 100644 index 000000000..75a88a8f9 --- /dev/null +++ b/neural_speed/models/requirements/chatglm-6b.sh @@ -0,0 +1,19 @@ +#!/bin/bash +#=============================================================================== +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#=============================================================================== + +# To avoid the error: 'ChatGLMTokenizer' object has no attribute 'sp_tokenizer' +pip install -r "$(dirname "${BASH_SOURCE[0]}")/common.txt" transformers==4.33.1 diff --git a/neural_speed/models/requirements/common.txt b/neural_speed/models/requirements/common.txt new file mode 100644 index 000000000..441da4dde --- /dev/null +++ b/neural_speed/models/requirements/common.txt @@ -0,0 +1,12 @@ +--extra-index-url https://download.pytorch.org/whl/cpu +torch==2.1.0+cpu +transformers +numpy +sentencepiece +protobuf<3.20 +einops +accelerate +peft +datasets +transformers_stream_generator +tiktoken diff --git a/neural_speed/models/requirements/mistral.txt b/neural_speed/models/requirements/mistral.txt new file mode 100644 index 000000000..786b72c27 --- /dev/null +++ b/neural_speed/models/requirements/mistral.txt @@ -0,0 +1,2 @@ +-r common.txt +transformers>=4.34.0 diff --git a/requirements.txt b/requirements.txt index f82674779..de06486c5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ -torch +--extra-index-url https://download.pytorch.org/whl/cpu +torch==2.1.0+cpu transformers numpy sentencepiece diff --git a/scripts/perplexity.py b/scripts/perplexity.py new file mode 100644 index 000000000..813868437 --- /dev/null +++ b/scripts/perplexity.py @@ -0,0 +1,276 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import math +import os +import pathlib +from typing import Dict, List + +import matplotlib.pyplot as plt +import torch +from tqdm import tqdm + +logging.basicConfig() +logger = logging.getLogger('perplexity') +''' +Preparing test dataset: + +>>> import datasets +>>> dataset = datasets.load_dataset("wikitext", "wikitext-2-raw-v1", split='test', num_proc=16) +>>> dataset.save_to_disk('~/wikitext-2-raw-v1-data-test') +>>> dataset = datasets.load_dataset("pg19", split='test', num_proc=16) +>>> dataset.save_to_disk('~/pg19-data-test') +''' + + +def try_resolve_dir(d): + resolved = pathlib.Path(d).expanduser().resolve() + if resolved.exists(): + return str(resolved) + return d + + +def get_ppl(sum_nll, sum_nll2, cnt: int): + ''' Get ppl and its standard deviation from sum of negative log likelihood ''' + nll = sum_nll / cnt + nll2 = sum_nll2 / cnt + ppl = math.exp(nll) + return ppl, 0. if cnt <= 1 else math.sqrt((nll2 - nll * nll) / (cnt - 1)) + + +def perplexity(model_name, dataset_name, **kwargs): + import datasets + from intel_extension_for_transformers.transformers import (AutoModelForCausalLM, WeightOnlyQuantConfig) + from transformers import AutoTokenizer, AutoConfig + model_name = try_resolve_dir(model_name) + dataset_name = try_resolve_dir(dataset_name) + + ctx_size = kwargs.get("ctx_size", 256) + prompt_size = kwargs.get("prompt_size", ctx_size // 4) # use one quarter as prompt + n_threads = kwargs.get("n_threads", len(os.sched_getaffinity(0))) # Note: linux only + n_pred_per_sample = kwargs.get("n_pred_per_sample", ctx_size * 2) + n_sampels = kwargs.get("n_sampels", 2) + data_text_concat = kwargs.get("data_text_concat", "wikitext-2-raw-v1" in dataset_name) # concat samples with `\n\n` + default_model_kwargs = {"n_batch": ctx_size, "ctx_size": ctx_size, "n_keep": 4} + + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + data = datasets.load_from_disk(dataset_name) + test_text = data['text'] + if data_text_concat: + test_text = ['\n\n'.join(test_text)] + + if n_sampels < 0: + n_sampels = len(test_text) + elif n_sampels > len(test_text): + logger.warning(f"Try to eval {n_sampels} samples but there are only {len(test_text)} in the dataset!") + n_sampels = len(test_text) + + test_ids = [] + with tqdm(total=n_sampels, desc="tokenizing") as pbar: + length_needed = prompt_size + n_pred_per_sample + for text in test_text: + if len(test_ids) > n_sampels: + break + ids = tokenizer(text, return_tensors="pt", max_length=length_needed, truncation=True).input_ids + if ids.shape.numel() >= length_needed: + test_ids.append(ids) + pbar.update(1) + + quantized_weight_path = kwargs.pop('quantized_weight_path', None) + if quantized_weight_path: + from intel_extension_for_transformers.llm.runtime.graph import Model + model = Model() + assert pathlib.Path(quantized_weight_path).is_file(), "Quantized weight not exist!" + model.bin_file = quantized_weight_path + model.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + model.model_type = Model.get_model_type(model.config) + model.tokenizer = tokenizer + else: + woq_kwargs = { + k: kwargs[k] + for k in kwargs + if k in ['use_cache', 'compute_dtype', 'weight_dtype', 'scale_dtype', 'group_size', 'use_ggml'] + } + woq_config = WeightOnlyQuantConfig(**woq_kwargs) + model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config, trust_remote_code=True) + + model_kwargs = {k: kwargs[k] for k in kwargs if k in ['n_keep', 'shift_roped_k', 'memory_dtype']} + model_kwargs = {**default_model_kwargs, **model_kwargs} + + ppl_hist = [{} for _ in range(n_sampels)] # ppl_hist[i_sample][end_pos] = ppl + sum_nll = [0. for _ in range(n_sampels)] # sum of negative log likelyhood + sum_nll2 = [0. for _ in range(n_sampels)] # sum of nll square + + pbar = tqdm(range(n_pred_per_sample * n_sampels)) + for i in pbar: + i_sample = i // n_pred_per_sample + i_pred = i % n_pred_per_sample + + is_first = (i_pred == 0) + + begin_pos = 0 if is_first else i_pred + prompt_size - 1 + end_pos = i_pred + prompt_size + cur_input = test_ids[i_sample][:, begin_pos:end_pos] + cur_target: torch.Tensor = test_ids[i_sample][:, end_pos] + out = model(cur_input, threads=n_threads, reinit=is_first, **model_kwargs) + logsoftmax = torch.from_numpy(out).log_softmax(-1) + nll = logsoftmax.take_along_dim(cur_target.view(-1, 1), 1) + assert len(nll) == 1 + nll_v = -nll.flatten().tolist()[0] + sum_nll[i_sample] += nll_v + sum_nll2[i_sample] += nll_v * nll_v + + cur_ppl, cur_sd = get_ppl(sum_nll[i_sample], sum_nll2[i_sample], i_pred + 1) + msg = f"Sample {i_sample + 1} / {n_sampels}; PPL = {cur_ppl:.4f} +/- {cur_ppl * cur_sd:.5f}" + pbar.set_description(msg, False) + ppl_hist[i_sample][end_pos] = cur_ppl + + return ppl_hist + + +def add_log_ppl_line(ax: plt.Axes, ppl_data: List[Dict[int, float]], label="log PPL"): + """ Plot PPL and return xmax / ymax""" + xs = [] + ys = [] + max_pos = max(max(d.keys()) for d in ppl_data) + for i in range(max_pos + 1): + ppls = [d[i] for d in ppl_data if i in d] + if not ppls: + continue + xs.append(i) + ys.append(math.log(sum(ppls) / len(ppls))) # average over samples + ax.plot(xs, ys, label=label) + + xmax = xs[torch.argmax(torch.tensor(ys)).item()] + ymax = max(ys) + return xmax, ymax, xs, ys + + +def draw_ppl(img_path: str, ppl_data: List[Dict[int, float]], ctx_size: int, model_title: str): + fig, ax = plt.subplots() + xmax, ymax, _, _ = add_log_ppl_line(ax, ppl_data) + ax.annotate(f"max={ymax:.4f}", (xmax, ymax)) + + ctx_line = ax.axvline(ctx_size, linestyle='--', color='r') + ctx_line.set_label('KV Cache Size') + ax.set_xlabel('Context Length') + ax.set_ylabel('Log Perplexity') + ax.legend() + + ax.set_title(model_title) + fig.suptitle("Language modeling perplexity") + fig.savefig(img_path) + + print(f"Max PPL: {math.exp(ymax)}") + return fig + + +def add_quant_args(parser: argparse.ArgumentParser): + group = parser.add_argument_group('quantize config') + group.add_argument('--quantized_weight_path', + type=str, + help="path to quantized weight; other quant args will be ignored if specified", + default="") + group.add_argument('--use_cache', action="store_true", help="Use local quantized model if file exists") + group.add_argument( + "--weight_dtype", + choices=["int4", "int8"], + help="Data type of quantized weight: int4/int8 (default: int4)", + default="int4", + ) + group.add_argument( + "--alg", + type=str, + help="Quantization algorithm to use: sym/asym (default: sym)", + default="sym", + ) + group.add_argument("--group_size", type=int, help="Group size: Int (default: 32)", default=32) + group.add_argument( + "--scale_dtype", + type=str, + help="Data type of scales: bf16/fp32 (default: fp32)", + default="fp32", + ) + group.add_argument( + "--compute_dtype", + type=str, + help="Data type of Gemm computation: int8/bf16/fp32 (default: int8)", + default="int8", + ) + group.add_argument( + "--use_ggml", + action="store_true", + help="enable ggml for quantization and inference", + ) + return group + + +def add_run_args(parser: argparse.ArgumentParser): + group = parser.add_argument_group('model run config') + group.add_argument( + "--n_keep", + type=int, + help="Number of tokens to keep from the initial prompt: Int (default: 0; -1 = all)", + default=1, + ) + group.add_argument( + "--shift_roped_k", + action="store_true", + help="Use ring-buffer and thus do not re-computing after reaching ctx_size (default: False)", + ) + group.add_argument("--memory_dtype", + type=str, + help="Data type of the kv memory", + choices=['f32', 'f16', 'auto'], + default="auto") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Evaluate perplexity for a model given a dataset") + parser.add_argument('--model_name', type=str, default="~/Llama-2-7b-chat-hf") + parser.add_argument('--dataset_name', type=str, default="~/pg19-data-test") + parser.add_argument('--ctx_size', type=int, default=256) + parser.add_argument('--prompt_size', type=int) + parser.add_argument('--n_threads', type=int) + parser.add_argument('--n_pred_per_sample', type=int) + parser.add_argument('--n_sampels', type=int) + parser.add_argument('--data_text_concat', action="store_true", default=None) + parser.add_argument('--fig_path', type=str, default="out/ppl.png") + add_quant_args(parser) + add_run_args(parser) + + ns_args = parser.parse_args() + args = vars(ns_args) + args = {k: args[k] for k in args if args[k] is not None} + + pathlib.Path.mkdir(pathlib.Path("out"), exist_ok=True) + ppl_data = perplexity(**args) + + # draw the graph + job_name = f"{ns_args.model_name}-{ns_args.weight_dtype}" + if ns_args.weight_dtype != 'fp32': + job_name += f"-{ns_args.compute_dtype}-g{ns_args.group_size}" + + job_name += f"-keep{ns_args.n_keep}" + draw_ppl(ns_args.fig_path, ppl_data, ns_args.ctx_size, job_name) + + # dump raw data + import json + with open('out/ppl_data.json', 'w') as f: + json.dump(ppl_data, f, indent=2) diff --git a/tests/model-test/cpp_graph_inference.sh b/tests/model-test/cpp_graph_inference.sh index 42d835e23..72a843feb 100644 --- a/tests/model-test/cpp_graph_inference.sh +++ b/tests/model-test/cpp_graph_inference.sh @@ -114,7 +114,11 @@ function ppl_eval() { if [[ "$ppl_mf16_test" = true ]]; then memory_dtype_list+=('f16') fi - + echo "====== Prepare Env ===============" + [[ $(pip list | grep intel_extension_for_transformers | wc -l) == 0 ]] && pip install intel_extension_for_transformers + [[ $(pip list | grep datasets | wc -l) == 0 ]] && pip install datasets + [[ $(pip list | grep transformers | wc -l) == 0 ]] && pip install transformers + pip list echo "======= PPL Evaluation Start =======" for memory_dtype in ${memory_dtype_list[@]}; do for ppl_dataset in ${ppl_dataset_list[@]}; do @@ -122,7 +126,7 @@ function ppl_eval() { local ppl_task_name="$task_name-ppl-$(basename -- "$ppl_dataset")-nctx$ppl_nctx-M$memory_dtype" echo "***** PPL: $ppl_task_name *****" OMP_NUM_THREADS=$(($n_cores * 1)) numactl -m 0 -C 0-$(($n_cores * 1 - 1)) \ - python scripts/perplexity.py --model_name "$model_path" --dataset_name "$ppl_dataset" --quantized_weight_path "$quantized_weight_path" --ctx_size $ppl_nctx --n_threads $n_cores --memory_dtype $memory_dtype 2>&1 | + python $working_dir/scripts/perplexity.py --model_name "$model_path" --dataset_name "$ppl_dataset" --quantized_weight_path "$quantized_weight_path" --ctx_size $ppl_nctx --n_threads $n_cores --memory_dtype $memory_dtype 2>&1 | tee "$WORKSPACE/$ppl_task_name.log" mv out/ppl.png "$WORKSPACE/$ppl_task_name.png" mv out/ppl_data.json "$WORKSPACE/$ppl_task_name.json" @@ -221,26 +225,26 @@ function main() { elif [[ "${model}" == "chatglm-6b" ]]; then quant_script="./build/bin/quant_chatglm" convert_script="${convert_script}/convert_chatglm.py" - infer_cmd="python ./scripts/inference.py" + infer_cmd="python $working_dir/scripts/inference.py" extension=" --model_name chatglm --tokenizer $model_path" - requirements_file="scripts/requirements/chatglm-6b.sh" + requirements_file="$working_dir/neural_speed/models/requirements/chatglm-6b.sh" elif [[ "${model}" == "baichuan2-13b" ]]; then quant_script="./build/bin/quant_baichuan" convert_script="${convert_script}/convert_baichuan.py" - infer_cmd="python ./scripts/inference.py" - requirements_file="scripts/requirements/baichuan.sh" + infer_cmd="python $working_dir/scripts/inference.py" + requirements_file="$working_dir/neural_speed/models/requirements/baichuan.sh" extension=" --model_name baichuan --tokenizer $model_path" elif [[ "${model}" == "baichuan-13b" ]]; then quant_script="./build/bin/quant_baichuan" convert_script="${convert_script}/convert_baichuan.py" - infer_cmd="python ./scripts/inference.py" + infer_cmd="python $working_dir/scripts/inference.py" extension=" --model_name baichuan --tokenizer $model_path" - requirements_file="scripts/requirements/baichuan.sh" + requirements_file="$working_dir/neural_speed/models/requirements/baichuan.sh" elif [[ "${model}" == "mistral-7b" ]]; then quant_script="./build/bin/quant_mistral" convert_script="${convert_script}/convert_mistral.py" infer_cmd="./build/bin/run_mistral" - requirements_file="scripts/requirements/mistral.txt" + requirements_file="$working_dir/neural_speed/models/requirements/mistral.txt" elif [[ "${model}" == "qwen-7b" ]]; then quant_script="./build/bin/quant_qwen" convert_script="${convert_script}/convert_qwen.py" @@ -285,7 +289,7 @@ function main() { if [[ "${compiler_version}" != "12.1.0" ]]; then conda install --update-deps -c conda-forge gxx==${compiler_version} gcc==${compiler_version} gxx_linux-64==${compiler_version} libstdcxx-ng sysroot_linux-64 -y fi - + export LD_LIBRARY_PATH=${HOME}/miniconda3/envs/${conda_env}/lib/:$LD_LIBRARY_PATH # setup conda env for LLM # get cpu info diff --git a/tests/model-test/cpp_graph_prompts.json b/tests/model-test/cpp_graph_prompts.json index 9d234a677..63f52ce89 100644 --- a/tests/model-test/cpp_graph_prompts.json +++ b/tests/model-test/cpp_graph_prompts.json @@ -42,7 +42,7 @@ "llama-2-7b-chat": "llama", "mistral-7b": "llama", "chatglm2": "chinese1", - "baichuan-13b": "chinese1", + "baichuan-13b": "chinese3", "baichuan2-13b": "chinese1", "chatglm-6b": "chinese2" }, @@ -51,7 +51,8 @@ "llama": "It is done, and submitted. You can play 'Survival of the Tastiest' on Android, and on the web. Playing on the web works, but you have to simulate multiple touch for table moving and that can be a bit confusing. There is a lot I'd like to talk about. I will go through every topic, insted of making the typical what went right/wrong list. Concept Working over the theme was probably one of the hardest tasks which I had to face. Originally, I had an idea of what kind of game I wanted to develop, gameplay wise - something with a lot of enemies/actors, simple graphics, maybe set in space, controlled from a top-down view. I was confident that I could fit any theme around it. In the end, the problem with a theme like 'Evolution' in a game is that evolution is unassisted. It happens through several seemingly random mutations over time, with the most apt permutation surviving. This genetic car simulator is, in my opinion, a great example of actual evolution of a species facing a challenge. But is it a game? In a game, you need to control something to reach an objective. That control goes against what evolution is supposed to be like. If you allow the user to pick how to evolve something, it's not evolution anymore - it's the equivalent of intelligent design, the fable invented by creationists to combat the idea of evolution. Being agnostic and a Pastafarian, that's not something that rubbed me the right way. Hence, my biggest dillema when deciding what to create was not with what I wanted to create, but with what I did not. I didn't want to create an 'intelligent design' simulator and wrongly call it evolution. This is a problem, of course, every other contestant also had to face. And judging by the entries submitted, not many managed to work around it. I'd say the only real solution was through the use of artificial selection, somehow. So far, I have not seen any entry using this at its core gameplay. Alas, this is just a fun competition and after a while I decided not to be as strict with the game idea, and allowed myself to pick whatever I thought would work out. My initial idea was to create something where humanity tried to evolve to a next level but had some kind of foe trying to stop them from doing so. I kind of had this image of human souls flying in space towards a monolith or a space baby (all based in 2001: A Space Odyssey of course) but I couldn't think of compelling (read: serious) mechanics for that. Borgs were my next inspiration, as their whole hypothesis fit pretty well into the evolution theme. But how to make it work? Are you the borg, or fighting the Borg? The third and final idea came to me through my girlfriend, who somehow gave me the idea of making something about the evolution of Pasta. The more I thought about it the more it sounded like it would work, so I decided to go with it. Conversations with my inspiring co-worker Roushey (who also created the 'Mechanical Underdogs' signature logo for my intros) further matured the concept, as it involved into the idea of having individual pieces of pasta flying around and trying to evolve until they became all-powerful. A secondary idea here was that the game would work to explain how the Flying Spaghetti Monster came to exist - by evolving from a normal dinner table. So the idea evolved more or less into this: you are sitting a table. You have your own plate, with is your 'base'. There are 5 other guests at the table, each with their own plate. Your plate can spawn little pieces of pasta. You do so by 'ordering' them through a menu. Some pastas are better than others; some are faster, some are stronger. They have varying 'costs', which are debited from your credits (you start with a number of credits). Once spawned, your pastas start flying around. Their instinct is to fly to other plates, in order to conquer them (the objective of the game is having your pasta conquer all the plates on the table). But they are really autonomous, so after being spawned, you have no control over your pasta (think DotA or LoL creeps). Your pasta doesn't like other people's pasta, so if they meet, they shoot sauce at each other until one dies. You get credits for other pastas your own pasta kill. Once a pasta is in the vicinity of a plate, it starts conquering it for its team. It takes around 10 seconds for a plate to be conquered; less if more pasta from the same team are around. If pasta from other team are around, though, they get locked down in their attempt, unable to conquer the plate, until one of them die (think Battlefield's standard 'Conquest' mode). You get points every second for every plate you own. Over time, the concept also evolved to use an Italian bistro as its main scenario. Carlos, Carlos' Bistro's founder and owner Setup No major changes were made from my work setup. I used FDT and Starling creating an Adobe AIR (ActionScript) project, all tools or frameworks I already had some knowledge with. One big change for me was that I livestreamed my work through a twitch.tv account. This was a new thing for me. As recommended by Roushey, I used a program called XSplit and I got to say, it is pretty amazing. It made the livestream pretty effortless and the features are awesome, even for the free version. It was great to have some of my friends watch me, and then interact with them and random people through chat. It was also good knowing that I was also recording a local version of the files, so I could make a timelapse video later. Knowing the video was being recorded also made me a lot more self-conscious about my computer use, as if someone was watching over my shoulder. It made me realize that sometimes I spend too much time in seemingly inane tasks (I ended up wasting the longest time just to get some text alignment the way I wanted - it'll probably drive someone crazy if they watch it) and that I do way too many typos where writing code. I pretty much spend half of the time writing a line and the other half fixing the crazy characters in it. My own stream was probably boring to watch since I was coding for the most time. But livestreaming is one of the cool things to do as a spectator too. It was great seeing other people working - I had a few tabs opened on my second monitor all the time. It's actually a bit sad, because if I could, I could have spent the whole weekend just watching other people working! But I had to do my own work, so I'd only do it once in a while, when resting for a bit. Design Although I wanted some simple, low-fi, high-contrast kind of design, I ended up going with somewhat realistic (vector) art. I think it worked very well, fitting the mood of the game, but I also went overboard. For example: to know the state of a plate (who owns it, who's conquering it and how much time they have left before conquering it, which pasta units are in the queue, etc), you have to look at the plate's bill. The problem I realized when doing some tests is that people never look at the bill! They think it's some kind of prop, so they never actually read its details. Plus, if you're zoomed out too much, you can't actually read it, so it's hard to know what's going on with the game until you zoom in to the area of a specific plate. One other solution that didn't turn out to be as perfect as I thought was how to indicate who a plate base belongs to. In the game, that's indicated by the plate's decoration - its color denotes the team owner. But it's something that fits so well into the design that people never realized it, until they were told about it. In the end, the idea of going with a full physical metaphor is one that should be done with care. Things that are very important risk becoming background noise, unless the player knows its importance. Originally, I wanted to avoid any kind of heads-up display in my game. In the end, I ended up adding it at the bottom to indicate your credits and bases owned, as well as the hideous out-of-place-and-still-not-obvious 'Call Waiter' button. But in hindsight, I should have gone with a simple HUD from the start, especially one that indicated each team's colors and general state of the game without the need for zooming in and out. Development Development went fast. But not fast enough. Even though I worked around 32+ hours for this Ludum Dare, the biggest problem that I had to face in the end was overscoping.", "gptj-6b": "It is done, and submitted. You can play 'Survival of the Tastiest' on Android, and on the web. Playing on the web works, but you have to simulate multiple touch for table moving and that can be a bit confusing. There is a lot I'd like to talk about. I will go through every topic, insted of making the typical what went right/wrong list. Concept Working over the theme was probably one of the hardest tasks which I had to face. Originally, I had an idea of what kind of game I wanted to develop, gameplay wise - something with a lot of enemies/actors, simple graphics, maybe set in space, controlled from a top-down view. I was confident that I could fit any theme around it. In the end, the problem with a theme like 'Evolution' in a game is that evolution is unassisted. It happens through several seemingly random mutations over time, with the most apt permutation surviving. This genetic car simulator is, in my opinion, a great example of actual evolution of a species facing a challenge. But is it a game? In a game, you need control something to reach an objective. That control goes against what evolution is supposed to be like. If you allow the user to pick how to evolve something, it's not evolution anymore - it's the equivalent of intelligent design, the fable invented by creationists to combat the idea of evolution. Being agnostic and a Pastafarian, that's not something that rubbed me the right way. Hence, my biggest dillema when deciding what to create was not with what I wanted to create, but with what I did not. I didn't want to create an 'intelligent design' simulator and wrongly call it evolution. This is a problem, of course, every other contestant also had to face. And judging by the entries submitted, not many managed to work around it. I'd say the only real solution was through the use of artificial selection, somehow. So far, I haven't seen any entry using this at its core gameplay. Alas, this is just a fun competition and after a while I decided not to be as strict with the game idea, and allowed myself to pick whatever I thought would work out. My initial idea was to create something where humanity tried to evolve to a next level, but had some kind of foe trying to stop them from doing so. I kind of had this image of human souls flying in space towards a monolith or a space baby (all based in 2001: A Space Odyssey of course) but I couldn't think of compelling (read: serious) mechanics for that. Borgs were my next inspiration, as their whole hypothesis fit pretty well into the evolution theme. But how to make it work? Are you the borg, or fighting the Borg? The third and final idea came to me through my girlfriend, who somehow gave me the idea of making something about the evolution of Pasta. The more I thought about it the more it sounded like it would work, so I decided to go with it. Conversations with my inspiring co-worker Roushey (who also created the 'Mechanical Underdogs' signature logo for my intros) further matured the concept, as it involved into the idea of having individual pieces of pasta flying around and trying to evolve until they became all-powerful. A secondary idea here was that the game would work to explain how the Flying Spaghetti Monster came to exist - by evolving from a normal dinner table. So the idea evolved more or less into this: you are sitting a table. You have your own plate, with is your 'base'. There are 5 other guests at the table, each with their own plate. Your plate can spawn little pieces of pasta. You do so by 'ordering' them through a menu. Some pastas are better than others; some are faster, some are stronger. They have varying 'costs', which are debited from your credits (you start with a number of credits). Once spawned, your pastas start flying around. Their instinct is to fly to other plates, in order to conquer them (the objective of the game is having your pasta conquer all the plates on the table). But they are really autonomous, so after being spawned, you have no control over your pasta (think DotA or LoL creeps). Your pasta doesn't like other people's pasta, so if they meet, they shoot sauce at each other until one dies. You get credits for other pastas your own pasta kill. Once a pasta is in the vicinity of a plate, it starts conquering it for its team. It takes around 10 seconds for a plate to be conquered; less if more pasta from the same team are around. If pasta from other team are around, though, they get locked down in their attempt, unable to conquer the plate, until one of them die (think Battlefield's standard 'Conquest' mode). You get points every second for every plate you own. Over time, the concept also evolved to use an Italian bistro as its main scenario. Carlos, Carlos' Bistro's founder and owner Setup No major changes were made from my work setup. I used FDT and Starling creating an Adobe AIR (ActionScript) project, all tools or frameworks I already had some knowledge with. One big change for me was that I livestreamed my work through a twitch.tv account. This was a new thing for me. As recommended by Roushey, I used a program called XSplit and I got to say, it is pretty amazing. It made the livestream pretty effortless and the features are awesome, even for the free version. It was great to have some of my friends watch me, and then interact with them and random people through chat. It was also good knowing that I was also recording a local version of the files, so I could make a timelapse video later. Knowing the video was being recorded also made me a lot more self-conscious about my computer use, as if someone was watching over my shoulder. It made me realize that sometimes I spend too much time in seemingly inane tasks (I ended up wasting the longest time just to get some text alignment the way I wanted - it'll probably drive someone crazy if they watch it) and that I do way too many typos where writing code. I pretty much spend half of the time writing a line and the other half fixing the crazy characters in it. My own stream was probably boring to watch since I was coding for the most time. But livestreaming is one of the cool things to do as a spectator too. It was great seeing other people working - I had a few tabs opened on my second monitor all the time. It's actually a bit sad, because if I could, I could have spent the whole weekend just watching other people working! But I had to do my own work, so I'd only do it once in a while, when resting for a bit. Design Although I wanted some simple, low-fi, high-contrast kind of design, I ended up going with somewhat realistic (vector) art. I think it worked very well, fitting the mood of the game, but I also went overboard. For example: to know the state of a plate (who owns it, who's conquering it and how much time they have left before conquering it, which pasta units are in the queue, etc), you have to look at the plate's bill. The problem I realized when doing some tests is that people never look at the bill! They think it's some kind of prop, so they never actually read its details. Plus, if you're zoomed out too much, you can't actually read it, so it's hard to know what's going on with the game until you zoom in to the area of a specific plate. One other solution that didn't turn out to be as perfect as I thought was how to indicate who a plate base belongs to. In the game, that's indicated by the plate's decoration - its color denotes the team owner. But it's something that fits so well into the design that people never realized it, until they were told about it. In the end, the idea of going with a full physical metaphor is one that should be done with care. Things that are very important risk becoming background noise, unless the player knows its importance. Originally, I wanted to avoid any kind of heads-up display in my game. In the end, I ended up adding it at the bottom to indicate your credits and bases owned, as well as the hideous out-of-place-and-still-not-obvious 'Call Waiter' button. But in hindsight, I should have gone with a simple HUD from the start, especially one that indicated each team's colors and general state of the game without the need for zooming in and out. Development Development went fast. But not fast enough. Even though I worked around 32+ hours for this Ludum Dare, the biggest problem I had to face in the end was overscoping. I had too much planned, and could not get it all done. Content-wise, I had several kinds of pasta planned - Wikipedia is just amazing in that regard, split into several different groups, from small Pastina to huge Pasta al forno. But because of time constraints, I ended up scratching most of them, and ended up with 5 different types of small pasta - barely something to start when talking about the evolution of Pasta. Pastas used in the game. Unfortunately, the macs where never used Which is one of the saddest things about the project, really. It had the framework and the features to allow an endless number of elements in there, but I just did not have time to draw the rest of the assets needed (something I loved to do).", "chinese1": "\"它完成了,并提交了。你可以在Android和网络上玩“美味生存”。在网络上玩是有效的,但你必须模拟多次触摸才能移动桌子,这可能会有点令人困惑。我想谈的有很多。我将浏览每个主题,而不是列出典型的正确/错误列表。概念 围绕这个主题工作可能是我必须面对的最艰巨的任务之一。最初,我有一个想法,我想开发什么样的游戏,游戏玩法明智有很多敌人/演员的东西,简单的图形,可能设置在太空中,从自上而下的视图控制。我相信我可以围绕它适合任何主题。最后,游戏中像“进化”这样的主题的问题在于进化是无辅助的。随着时间的推移,它通过几个看似随机的突变发生,最合适的排列幸存下来。在我看来,这个基因汽车模拟器是面临挑战的物种实际进化的一个很好的例子。但这是游戏吗?在游戏中,您需要控制某些东西才能达到目标。这种控制违背了进化应该是什么样子。如果你允许用户选择如何进化某些东西,它就不再是进化了——它相当于智能设计,是创造论者发明的寓言,用来对抗进化论的想法。作为不可知论者和意大利面主义者,这不是以正确的方式摩擦我的东西。因此,在决定创作什么时,我最大的困惑不是我想创造什么,而是我不想要创造什么。我不想创建一个“智能设计”模拟器,并错误地称之为进化。这是一个问题,当然,其他参赛者也都要面对。从提交的条目来看,没有多少人设法解决这个问题。我想说,唯一真正的解决方案是通过使用人工选择,不知何故。到目前为止,我还没有看到任何条目在其核心游戏玩法中使用它。唉,这只是一个有趣的比赛,过了一段时间,我决定不那么严格地要求游戏理念,并允许自己选择我认为可行的任何内容。我最初的想法是创造一些东西,让人类试图进化到一个新的水平,但有某种敌人试图阻止他们这样做。我有点像人类灵魂在太空中飞向巨石或太空婴儿的图像(当然都是基于2001:太空漫游),但我想不出令人信服的(阅读:严肃的)机制。博格人是我的下一个灵感来源,因为他们的整个假设非常符合进化论的主题。但是如何让它工作呢?你是博格人,还是与博格人战斗?你是博格人,还是与博格人战斗?它完成了,并提交了。你可以在Android和网络上玩“美味生存”。在网络上玩是有效的,但你必须模拟多次触摸才能移动桌子,这可能会有点令人困惑。我想谈的有很多。我将浏览每个主题,而不是列出典型的正确/错误列表。概念围绕这个主题工作可能是我必须面对的最艰巨的任务之一。最初,我有一个想法,我想开发什么样的游戏,游戏玩法明智有很多敌人/演员的东西,简单的图形,可能设置在太空中,从自上而下的视图控制。我相信我可以围绕它适合任何主题。最后,游戏中像“进化”这样的主题的问题在于进化是无辅助的。随着时间的推移,它通过几个看似随机的突变发生,最合适的排列幸存下来。在我看来,这个基因汽车模拟器是面临挑战的物种实际进化的一个很好的例子。但这是游戏吗?在游戏中,您需要控制某些东西才能达到目标。这种控制违背了进化应该是什么样子。如果你允许用户选择如何进化某些东西,它就不再是进化了——它相当于智能设计,是创造论者发明的寓言,用来对抗进化论的想法。作为不可知论者和意大利面主义者,这不是以正确的方式摩擦我的东西。因此,在决定创作什么时,我最大的困惑不是我想创造什么,而是我不想要创造什么。我不想创建一个“智能设计”模拟器,并错误地称之为进化。这是一个问题,当然,其他参赛者也都要面对。从提交的条目来看,没有多少人设法解决这个问题。我想说,唯一真正的解决方案是通过使用人工选择,不知何故。到目前为止,我还没有看到任何条目在其核心游戏玩法中使用它。唉,这只是一个有趣的比赛,过了一段时间,我决定不那么严格地要求游戏理念,并允许自己选择我认为可行的任何内容。我最初的想法是创造一些东西,让人类试图进化到一个新的水平,但有某种敌人试图阻止他们这样做。我有点像人类灵魂在太空中飞向巨石或太空婴儿的图像(当然都是基于2001:太空漫游),但我想不出令人信服的(阅读:严肃的)机制。博格人是我的下一个灵感来源,因为他们的整个假设非常符合进化论的主题。但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?工作呢?它完成了,并提交了。你可以在Android和网络上玩“美味生存”。在网络上玩是有效的,但你必须模拟多次触摸才能移动桌子,这可能会有点令人困惑。我想谈的有很多。我将浏览每个主题,而不是列出典型的正确/错误列表。概念围绕这个主题工作可能是我必须面对的最艰巨的任务之一。最初,我有一个想法,我想开发什么样的游戏,游戏玩法明智有很多敌人/演员的东西,简单的图形,可能设置在太空中,从自上而下的视图控制。我相信我可以围绕它适合任何主题。最后,游戏中像“进化”这样的主题的问题在于进化是无辅助的。随着时间的推移,它通过几个看似随机的突变发生,最合适的排列幸存下来。在我看来,这个基因汽车模拟器是面临挑战的物种实际进化的一个很好的例子。但这是游戏吗?在游戏中,您需要控制某些东西才能达到目标。这种控制违背了进化应该是什么样子。如果你允许用户选择如何进化某些东西,它就不再是进化了——它相当于智能设计,是创造论者发明的寓言,用来对抗进化论的想法。作为不可知论者和意大利面主义者,这不是以正确的方式摩擦我的东西。因此,在决定创作什么时,我最大的困惑不是我想创造什么,而是我不想要创造什么。我不想创建一个“智能设计”模拟器,并错误地称之为进化。这是一个问题,当然,其他参赛者也都要面对。从提交的条目来看,没有多少人设法解决这个问题。我想说,唯一真正的解决方案是通过使用人工选择,不知何故。到目前为止,我还没有看到任何条目在其核心游戏玩法中使用它。唉,这只是一个有趣的比赛,过了一段时间,我决定不那么严格地要求游戏理念,并允许自己选择我认为可行的任何内容。我最初的想法是创造一些东西,让人类试图进化到一个新的水平,但有某种敌人试图阻止他们这样做。我有点像人类灵魂在太空中飞向巨石或太空婴儿的图像(当然都是基于2001:太空漫游),但我想不出令人信服的(阅读:严肃的)机制。博格人是我的下一个灵感来源,因为他们的整个假设非常符合进化论的主题。但是如何让它工作呢?你是博格人,还是与博格人战斗?你是博格人,还是与博格人战斗?它完成了,并提交了。你可以在Android和网络上玩“美味生存”。在网络上玩是有效的,但你必须模拟多次触摸才能移动桌子,这可能会有点令人困惑。我想谈的有很多。我将浏览每个主题,而不是列出典型的正确/错误列表。概念围绕这个主题工作可能是我必须面对的最艰巨的任务之一。最初,我有一个想法,我想开发什么样的游戏,游戏玩法明智有很多敌人/演员的东西,简单的图形,可能设置在太空中,从自上而下的视图控制。我相信我可以围绕它适合任何主题。最后,游戏中像“进化”这样的主题的问题在于进化是无辅助的。随着时间的推移,它通过几个看似随机的突变发生,最合适的排列幸存下来。在我看来,这个基因汽车模拟器是面临挑战的物种实际进化的一个很好的例子。但这是游戏吗?在游戏中,您需要控制某些东西才能达到目标。这种控制违背了进化应该是什么样子。如果你允许用户选择如何进化某些东西,它就不再是进化了——它相当于智能设计,是创造论者发明的寓言,用来对抗进化论的想法。作为不可知论者和意大利面主义者,这不是以正确的方式摩擦我的东西。因此,在决定创作什么时,我最大的困惑不是我想创造什么,而是我不想要创造什么。我不想创建一个“智能设计”模拟器,并错误地称之为进化。这是一个问题,当然,其他参赛者也都要面对。从提交的条目来看,没有多少人设法解决这个问题。我想说,唯一真正的解决方案是通过使用人工选择,不知何故。到目前为止,我还没有看到任何条目在其核心游戏玩法中使用它。唉,这只是一个有趣的比赛,过了一段时间,我决定不那么严格地要求游戏理念,并允许自己选择我认为可行的任何内容。我最初的想法是创造一些东西,让人类试图进化到一个新的水平,但有某种敌人试图阻止他们这样做。我有点像人类灵魂在太空中飞向巨石或太空婴儿的图像(当然都是基于2001:太空漫游),但我想不出令人信服的(阅读:严肃的)机制。博格人是我的下一个灵感来源,因为他们的整个假设非常符合进化论的主题。但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢但是工作?\"", - "chinese2": "\"它完成了,并提交了。你可以在Android和网络上玩“美味生存”。在网络上玩是有效的,但你必须模拟多次触摸才能移动桌子,这可能会有点令人困惑。我想谈的有很多。我将浏览每个主题,而不是列出典型的正确/错误列表。概念 围绕这个主题工作可能是我必须面对的最艰巨的任务之一。最初,我有一个想法,我想开发什么样的游戏,游戏玩法明智有很多敌人/演员的东西,简单的图形,可能设置在太空中,从自上而下的视图控制。我相信我可以围绕它适合任何主题。最后,游戏中像“进化”这样的主题的问题在于进化是无辅助的。随着时间的推移,它通过几个看似随机的突变发生,最合适的排列幸存下来。在我看来,这个基因汽车模拟器是面临挑战的物种实际进化的一个很好的例子。但这是游戏吗?在游戏中,您需要控制某些东西才能达到目标。这种控制违背了进化应该是什么样子。如果你允许用户选择如何进化某些东西,它就不再是进化了——它相当于智能设计,是创造论者发明的寓言,用来对抗进化论的想法。作为不可知论者和意大利面主义者,这不是以正确的方式摩擦我的东西。因此,在决定创作什么时,我最大的困惑不是我想创造什么,而是我不想要创造什么。我不想创建一个“智能设计”模拟器,并错误地称之为进化。这是一个问题,当然,其他参赛者也都要面对。从提交的条目来看,没有多少人设法解决这个问题。我想说,唯一真正的解决方案是通过使用人工选择,不知何故。到目前为止,我还没有看到任何条目在其核心游戏玩法中使用它。唉,这只是一个有趣的比赛,过了一段时间,我决定不那么严格地要求游戏理念,并允许自己选择我认为可行的任何内容。我最初的想法是创造一些东西,让人类试图进化到一个新的水平,但有某种敌人试图阻止他们这样做。我有点像人类灵魂在太空中飞向巨石或太空婴儿的图像(当然都是基于2001:太空漫游),但我想不出令人信服的(阅读:严肃的)机制。博格人是我的下一个灵感来源,因为他们的整个假设非常符合进化论的主题。但是如何让它工作呢?你是博格人,还是与博格人战斗?你是博格人,还是与博格人战斗?它完成了,并提交了。你可以在Android和网络上玩“美味生存”。在网络上玩是有效的,但你必须模拟多次触摸才能移动桌子,这可能会有点令人困惑。我想谈的有很多。我将浏览每个主题,而不是列出典型的正确/错误列表。概念围绕这个主题工作可能是我必须面对的最艰巨的任务之一。最初,我有一个想法,我想开发什么样的游戏,游戏玩法明智有很多敌人/演员的东西,简单的图形,可能设置在太空中,从自上而下的视图控制。我相信我可以围绕它适合任何主题。最后,游戏中像“进化”这样的主题的问题在于进化是无辅助的。随着时间的推移,它通过几个看似随机的突变发生,最合适的排列幸存下来。在我看来,这个基因汽车模拟器是面临挑战的物种实际进化的一个很好的例子。但这是游戏吗?在游戏中,您需要控制某些东西才能达到目标。这种控制违背了进化应该是什么样子。如果你允许用户选择如何进化某些东西,它就不再是进化了——它相当于智能设计,是创造论者发明的寓言,用来对抗进化论的想法。作为不可知论者和意大利面主义者,这不是以正确的方式摩擦我的东西。因此,在决定创作什么时,我最大的困惑不是我想创造什么,而是我不想要创造什么。我不想创建一个“智能设计”模拟器,并错误地称之为进化。这是一个问题,当然,其他参赛者也都要面对。从提交的条目来看,没有多少人设法解决这个问题。我想说,唯一真正的解决方案是通过使用人工选择,不知何故。到目前为止,我还没有看到任何条目在其核心游戏玩法中使用它。唉,这只是一个有趣的比赛,过了一段时间,我决定不那么严格地要求游戏理念,并允许自己选择我认为可行的任何内容。我最初的想法是创造一些东西,让人类试图进化到一个新的水平,但有某种敌人试图阻止他们这样做。我有点像人类灵魂在太空中飞向巨石或太空婴儿的图像(当然都是基于2001:太空漫游),但我想不出令人信服的(阅读:严肃的)机制。博格人是我的下一个灵感来源,因为他们的整个假设非常符合进化论的主题。但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?工作呢?它完成了,并提交了。你可以在Android和网络上玩“美味生存”。在网络上玩是有效的,但你必须模拟多次触摸才能移动桌子,这可能会有点令人困惑。我想谈的有很多。我将浏览每个主题,而不是列出典型的正确/错误列表。概念围绕这个主题工作可能是我必须面对的最艰巨的任务之一。最初,我有一个想法,我想开发什么样的游戏,游戏玩法明智有很多敌人/演员的东西,简单的图形,可能设置在太空中,从自上而下的视图控制。我相信我可以围绕它适合任何主题。最后,游戏中像“进化”这样的主题的问题在于进化是无辅助的。随着时间的推移,它通过几个看似随机的突变发生,最合适的排列幸存下来。在我看来,这个基因汽车模拟器是面临挑战的物种实际进化的一个很好的例子。但这是游戏吗?在游戏中,您需要控制某些东西才能达到目标。这种控制违背了进化应该是什么样子。如果你允许用户选择如何进化某些东西,它就不再是进化了——它相当于智能设计,是创造论者发明的寓言,用来对抗进化论的想法。作为不可知论者和意大利面主义者,这不是以正确的方式摩擦我的东西。因此,在决定创作什么时,我最大的困惑不是我想创造什么,而是我不想要创造什么。我不想创建一个“智能设计”模拟器,并错误地称之为进化。这是一个问题,当然,其他参赛者也都要面对。从提交的条目来看,没有多少人设法解决这个问题。我想说,唯一真正的解决方案是通过使用人工选择,不知何故。到目前为止,我还没有看到任何条目在其核心游戏玩法中使用它。唉,这只是一个有趣的比赛,过了一段时间,我决定不那么严格地要求游戏理念,并允许自己选择我认为可行的任何内容。我最初的想法是创造一些东西,让人类试图进化到一个新的水平,但有某种敌人试图阻止他们这样做。我有点像人类灵魂在太空中飞向巨石或太空婴儿的图像(当然都是基于2001:太空漫游),但我想不出令人信服的(阅读:严肃的)机制。博格人是我的下一个灵感来源,因为他们的整个假设非常符合进化论的主题。但是如何让它工作呢?你是博格人,还是与博格人战斗?你是博格人,还是与博格人战斗?它完成了,并提交了。你可以在Android和网络上玩“美味生存”。在网络上玩是有效的,但你必须模拟多次触摸才能移动桌子,这可能会有点令人困惑。我想谈的有很多。我将浏览每个主题,而不是列出典型的正确/错误列表。概念围绕这个主题工作可能是我必须面对的最艰巨的任务之一。最初,我有一个想法,我想开发什么样的游戏,游戏玩法明智有很多敌人/演员的东西,简单的图形,可能设置在太空中,从自上而下的视图控制。我相信我可以围绕它适合任何主题。最后,游戏中像“进化”这样的主题的问题在于进化是无辅助的。随着时间的推移,它通过几个看似随机的突变发生,最合适的排列幸存下来。在我看来,这个基因汽车模拟器是面临挑战的物种实际进化的一个很好的例子。但这是游戏吗?在游戏中,您需要控制某些东西才能达到目标。这种控制违背了进化应该是什么样子。如果你允许用户选择如何进化某些东西,它就不再是进化了——它相当于智能设计,是创造论者发明的寓言,用来对抗进化论的想法。作为不可知论者和意大利面主义者,这不是以正确的方式摩擦我的东西。因此,在决定创作什么时,我最大的困惑不是我想创造什么,而是我不想要创造什么。我不想创建一个“智能设计”模拟器,并错误地称之为进化。这是一个问题,当然,其他参赛者也都要面对。从提交的条目来看,没有多少人设法解决这个问题。我想说,唯一真正的解决方案是通过使用人工选择,不知何故。到目前为止,我还没有看到任何条目在其核心游戏玩法中使用它。唉,这只是一个有趣的比赛,过了一段时间,我决定不那么严格地要求游戏理念,并允许自己选择我认为可行的任何内容。我最初的想法是创造一些东西,让人类试图进化到一个新的水平,但有某种敌人试图阻止他们这样做。我有点像人类灵魂在太空中飞向巨石或太空婴儿的图像(当然都是基于2001:太空漫游),但我想不出令人信服的(阅读:严肃的)机制。博格人是我的下一个灵感来源,因为他们的整个假设非常符合进化论的主题。但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢但是工作?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢但是工作?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢但是工作?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢但是工作?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢但是工作?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢但是工作?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是?\"" + "chinese2": "\"它完成了,并提交了。你可以在Android和网络上玩“美味生存”。在网络上玩是有效的,但你必须模拟多次触摸才能移动桌子,这可能会有点令人困惑。我想谈的有很多。我将浏览每个主题,而不是列出典型的正确/错误列表。概念 围绕这个主题工作可能是我必须面对的最艰巨的任务之一。最初,我有一个想法,我想开发什么样的游戏,游戏玩法明智有很多敌人/演员的东西,简单的图形,可能设置在太空中,从自上而下的视图控制。我相信我可以围绕它适合任何主题。最后,游戏中像“进化”这样的主题的问题在于进化是无辅助的。随着时间的推移,它通过几个看似随机的突变发生,最合适的排列幸存下来。在我看来,这个基因汽车模拟器是面临挑战的物种实际进化的一个很好的例子。但这是游戏吗?在游戏中,您需要控制某些东西才能达到目标。这种控制违背了进化应该是什么样子。如果你允许用户选择如何进化某些东西,它就不再是进化了——它相当于智能设计,是创造论者发明的寓言,用来对抗进化论的想法。作为不可知论者和意大利面主义者,这不是以正确的方式摩擦我的东西。因此,在决定创作什么时,我最大的困惑不是我想创造什么,而是我不想要创造什么。我不想创建一个“智能设计”模拟器,并错误地称之为进化。这是一个问题,当然,其他参赛者也都要面对。从提交的条目来看,没有多少人设法解决这个问题。我想说,唯一真正的解决方案是通过使用人工选择,不知何故。到目前为止,我还没有看到任何条目在其核心游戏玩法中使用它。唉,这只是一个有趣的比赛,过了一段时间,我决定不那么严格地要求游戏理念,并允许自己选择我认为可行的任何内容。我最初的想法是创造一些东西,让人类试图进化到一个新的水平,但有某种敌人试图阻止他们这样做。我有点像人类灵魂在太空中飞向巨石或太空婴儿的图像(当然都是基于2001:太空漫游),但我想不出令人信服的(阅读:严肃的)机制。博格人是我的下一个灵感来源,因为他们的整个假设非常符合进化论的主题。但是如何让它工作呢?你是博格人,还是与博格人战斗?你是博格人,还是与博格人战斗?它完成了,并提交了。你可以在Android和网络上玩“美味生存”。在网络上玩是有效的,但你必须模拟多次触摸才能移动桌子,这可能会有点令人困惑。我想谈的有很多。我将浏览每个主题,而不是列出典型的正确/错误列表。概念围绕这个主题工作可能是我必须面对的最艰巨的任务之一。最初,我有一个想法,我想开发什么样的游戏,游戏玩法明智有很多敌人/演员的东西,简单的图形,可能设置在太空中,从自上而下的视图控制。我相信我可以围绕它适合任何主题。最后,游戏中像“进化”这样的主题的问题在于进化是无辅助的。随着时间的推移,它通过几个看似随机的突变发生,最合适的排列幸存下来。在我看来,这个基因汽车模拟器是面临挑战的物种实际进化的一个很好的例子。但这是游戏吗?在游戏中,您需要控制某些东西才能达到目标。这种控制违背了进化应该是什么样子。如果你允许用户选择如何进化某些东西,它就不再是进化了——它相当于智能设计,是创造论者发明的寓言,用来对抗进化论的想法。作为不可知论者和意大利面主义者,这不是以正确的方式摩擦我的东西。因此,在决定创作什么时,我最大的困惑不是我想创造什么,而是我不想要创造什么。我不想创建一个“智能设计”模拟器,并错误地称之为进化。这是一个问题,当然,其他参赛者也都要面对。从提交的条目来看,没有多少人设法解决这个问题。我想说,唯一真正的解决方案是通过使用人工选择,不知何故。到目前为止,我还没有看到任何条目在其核心游戏玩法中使用它。唉,这只是一个有趣的比赛,过了一段时间,我决定不那么严格地要求游戏理念,并允许自己选择我认为可行的任何内容。我最初的想法是创造一些东西,让人类试图进化到一个新的水平,但有某种敌人试图阻止他们这样做。我有点像人类灵魂在太空中飞向巨石或太空婴儿的图像(当然都是基于2001:太空漫游),但我想不出令人信服的(阅读:严肃的)机制。博格人是我的下一个灵感来源,因为他们的整个假设非常符合进化论的主题。但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?工作呢?它完成了,并提交了。你可以在Android和网络上玩“美味生存”。在网络上玩是有效的,但你必须模拟多次触摸才能移动桌子,这可能会有点令人困惑。我想谈的有很多。我将浏览每个主题,而不是列出典型的正确/错误列表。概念围绕这个主题工作可能是我必须面对的最艰巨的任务之一。最初,我有一个想法,我想开发什么样的游戏,游戏玩法明智有很多敌人/演员的东西,简单的图形,可能设置在太空中,从自上而下的视图控制。我相信我可以围绕它适合任何主题。最后,游戏中像“进化”这样的主题的问题在于进化是无辅助的。随着时间的推移,它通过几个看似随机的突变发生,最合适的排列幸存下来。在我看来,这个基因汽车模拟器是面临挑战的物种实际进化的一个很好的例子。但这是游戏吗?在游戏中,您需要控制某些东西才能达到目标。这种控制违背了进化应该是什么样子。如果你允许用户选择如何进化某些东西,它就不再是进化了——它相当于智能设计,是创造论者发明的寓言,用来对抗进化论的想法。作为不可知论者和意大利面主义者,这不是以正确的方式摩擦我的东西。因此,在决定创作什么时,我最大的困惑不是我想创造什么,而是我不想要创造什么。我不想创建一个“智能设计”模拟器,并错误地称之为进化。这是一个问题,当然,其他参赛者也都要面对。从提交的条目来看,没有多少人设法解决这个问题。我想说,唯一真正的解决方案是通过使用人工选择,不知何故。到目前为止,我还没有看到任何条目在其核心游戏玩法中使用它。唉,这只是一个有趣的比赛,过了一段时间,我决定不那么严格地要求游戏理念,并允许自己选择我认为可行的任何内容。我最初的想法是创造一些东西,让人类试图进化到一个新的水平,但有某种敌人试图阻止他们这样做。我有点像人类灵魂在太空中飞向巨石或太空婴儿的图像(当然都是基于2001:太空漫游),但我想不出令人信服的(阅读:严肃的)机制。博格人是我的下一个灵感来源,因为他们的整个假设非常符合进化论的主题。但是如何让它工作呢?你是博格人,还是与博格人战斗?你是博格人,还是与博格人战斗?它完成了,并提交了。你可以在Android和网络上玩“美味生存”。在网络上玩是有效的,但你必须模拟多次触摸才能移动桌子,这可能会有点令人困惑。我想谈的有很多。我将浏览每个主题,而不是列出典型的正确/错误列表。概念围绕这个主题工作可能是我必须面对的最艰巨的任务之一。最初,我有一个想法,我想开发什么样的游戏,游戏玩法明智有很多敌人/演员的东西,简单的图形,可能设置在太空中,从自上而下的视图控制。我相信我可以围绕它适合任何主题。最后,游戏中像“进化”这样的主题的问题在于进化是无辅助的。随着时间的推移,它通过几个看似随机的突变发生,最合适的排列幸存下来。在我看来,这个基因汽车模拟器是面临挑战的物种实际进化的一个很好的例子。但这是游戏吗?在游戏中,您需要控制某些东西才能达到目标。这种控制违背了进化应该是什么样子。如果你允许用户选择如何进化某些东西,它就不再是进化了——它相当于智能设计,是创造论者发明的寓言,用来对抗进化论的想法。作为不可知论者和意大利面主义者,这不是以正确的方式摩擦我的东西。因此,在决定创作什么时,我最大的困惑不是我想创造什么,而是我不想要创造什么。我不想创建一个“智能设计”模拟器,并错误地称之为进化。这是一个问题,当然,其他参赛者也都要面对。从提交的条目来看,没有多少人设法解决这个问题。我想说,唯一真正的解决方案是通过使用人工选择,不知何故。到目前为止,我还没有看到任何条目在其核心游戏玩法中使用它。唉,这只是一个有趣的比赛,过了一段时间,我决定不那么严格地要求游戏理念,并允许自己选择我认为可行的任何内容。我最初的想法是创造一些东西,让人类试图进化到一个新的水平,但有某种敌人试图阻止他们这样做。我有点像人类灵魂在太空中飞向巨石或太空婴儿的图像(当然都是基于2001:太空漫游),但我想不出令人信服的(阅读:严肃的)机制。博格人是我的下一个灵感来源,因为他们的整个假设非常符合进化论的主题。但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢但是工作?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢但是工作?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢但是工作?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢但是工作?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢但是工作?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢但是工作?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是?\"", + "chinese3": "\"它完成了,并提交了。你可以在Android和网络上玩“美味生存”。在网络上玩是有效的,但你必须模拟多次触摸才能移动桌子,这可能会有点令人困惑。我想谈的有很多。我将浏览每个主题,而不是列出典型的正确/错误列表。概念 围绕这个主题工作可能是我必须面对的最艰巨的任务之一。最初,我有一个想法,我想开发什么样的游戏,游戏玩法明智有很多敌人/演员的东西,简单的图形,可能设置在太空中,从自上而下的视图控制。我相信我可以围绕它适合任何主题。最后,游戏中像“进化”这样的主题的问题在于进化是无辅助的。随着时间的推移,它通过几个看似随机的突变发生,最合适的排列幸存下来。在我看来,这个基因汽车模拟器是面临挑战的物种实际进化的一个很好的例子。但这是游戏吗?在游戏中,您需要控制某些东西才能达到目标。这种控制违背了进化应该是什么样子。如果你允许用户选择如何进化某些东西,它就不再是进化了——它相当于智能设计,是创造论者发明的寓言,用来对抗进化论的想法。作为不可知论者和意大利面主义者,这不是以正确的方式摩擦我的东西。因此,在决定创作什么时,我最大的困惑不是我想创造什么,而是我不想要创造什么。我不想创建一个“智能设计”模拟器,并错误地称之为进化。这是一个问题,当然,其他参赛者也都要面对。从提交的条目来看,没有多少人设法解决这个问题。我想说,唯一真正的解决方案是通过使用人工选择,不知何故。到目前为止,我还没有看到任何条目在其核心游戏玩法中使用它。唉,这只是一个有趣的比赛,过了一段时间,我决定不那么严格地要求游戏理念,并允许自己选择我认为可行的任何内容。我最初的想法是创造一些东西,让人类试图进化到一个新的水平,但有某种敌人试图阻止他们这样做。我有点像人类灵魂在太空中飞向巨石或太空婴儿的图像(当然都是基于2001:太空漫游),但我想不出令人信服的(阅读:严肃的)机制。博格人是我的下一个灵感来源,因为他们的整个假设非常符合进化论的主题。但是如何让它工作呢?你是博格人,还是与博格人战斗?你是博格人,还是与博格人战斗?它完成了,并提交了。你可以在Android和网络上玩“美味生存”。在网络上玩是有效的,但你必须模拟多次触摸才能移动桌子,这可能会有点令人困惑。我想谈的有很多。我将浏览每个主题,而不是列出典型的正确/错误列表。概念围绕这个主题工作可能是我必须面对的最艰巨的任务之一。最初,我有一个想法,我想开发什么样的游戏,游戏玩法明智有很多敌人/演员的东西,简单的图形,可能设置在太空中,从自上而下的视图控制。我相信我可以围绕它适合任何主题。最后,游戏中像“进化”这样的主题的问题在于进化是无辅助的。随着时间的推移,它通过几个看似随机的突变发生,最合适的排列幸存下来。在我看来,这个基因汽车模拟器是面临挑战的物种实际进化的一个很好的例子。但这是游戏吗?在游戏中,您需要控制某些东西才能达到目标。这种控制违背了进化应该是什么样子。如果你允许用户选择如何进化某些东西,它就不再是进化了——它相当于智能设计,是创造论者发明的寓言,用来对抗进化论的想法。作为不可知论者和意大利面主义者,这不是以正确的方式摩擦我的东西。因此,在决定创作什么时,我最大的困惑不是我想创造什么,而是我不想要创造什么。我不想创建一个“智能设计”模拟器,并错误地称之为进化。这是一个问题,当然,其他参赛者也都要面对。从提交的条目来看,没有多少人设法解决这个问题。我想说,唯一真正的解决方案是通过使用人工选择,不知何故。到目前为止,我还没有看到任何条目在其核心游戏玩法中使用它。唉,这只是一个有趣的比赛,过了一段时间,我决定不那么严格地要求游戏理念,并允许自己选择我认为可行的任何内容。我最初的想法是创造一些东西,让人类试图进化到一个新的水平,但有某种敌人试图阻止他们这样做。我有点像人类灵魂在太空中飞向巨石或太空婴儿的图像(当然都是基于2001:太空漫游),但我想不出令人信服的(阅读:严肃的)机制。博格人是我的下一个灵感来源,因为他们的整个假设非常符合进化论的主题。但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?但是如何让它工作呢?工作呢?它完成了,并提交了。你可以在Android和网络上玩“美味生存”。在网络上玩是有效的,但你必须模拟多次触摸才能移动桌子,这可能会有点令人困惑。我想谈的有很多。我将浏览每个主题,而不是列出典型的正确/错误列表。概念围绕这个主题工作可能是我必须面对的最艰巨的任务之一。最初,我有一个想法,我想开发什么样的游戏,游戏玩法明智有很多敌人/演员的东西,简单的图形,可能设置在太空中,从自上而下的视图控制。我相信我可以围绕它适合任何主题。最后,游戏中像“进化”这样的主题的问题在于进化是无辅助的。随着时间的推移,它通过几个看似随机的突变发生,最合适的排列幸存下来。在我看来,这个基因汽车模拟器是面临挑战的物种实际进化的一个很好的例子。但这是游戏吗?在游戏中,您需要控制某些东西才能达到目标。这种控制违背了进化应该是什么样子。如果你允许用户选择如何进化某些东西,它就不再是进化了——它相当于智能设计,是创造论者发明的寓言,用来对抗进化论的想法。作为不可知论者和意大利面主义者,这不是以正确的方式摩擦我的东西。因此,在决定创作什么时,我最大的困惑不是我想创造什么,而是我不想要创造什么。我不想创建一个“智能设计”模拟器,并错误地称之为进化。这是一个问题,当然,其他参赛者也都要面对。从提交的条目来看,没有多少人设法解决这个问题。我想说,唯一真正的解决方案是通过使用人工选择,不知何故。到目前为止,我还没有看到任何条目在其核心游戏玩法中使用它。唉,这只是一个有趣的比赛,过了一段时间,我决定不那么严格地要求游戏理念,并允许自己选择我认为可行的任何内容。我最初的想法是创造一些东西,让人类试图进化到一个新的水平,但有某种敌人试图阻止他们这样做。我有点像人类灵魂在太空中飞向巨石或太空婴儿的图像(当然都是基于2001:太空漫游),但我想不出令人信服的(阅读:严肃的)机制。博格人是我的下一个灵感来源,因为他们的整个假设非常符合进化论的主题。但是如何让它工作呢?你是博格人,还是与博格人战斗?你是博格人,还是与博格人战斗?它完成了,并提交了。你可以在Android和网络上玩“美味生存”。在网络上玩是有效的,但你必须模拟多次触摸才能移动桌子,这可能会有点令人困惑。我想谈的有很多。我将浏览每个主题,而不是列出典型的正确/错误列表。概念围绕这个主题工作可能是我必须面对的最艰巨的任务之一。最初,我有一个想法,我想开发什么样的游戏,游戏玩法明智有很多敌人/演员的东西,简单的图形,可能设置在太空中,从自上而下的视图控制。我相信我可以围绕它适合任何主题。最后,游戏中像“进化”这样的主题的问题在于进化是无辅助的。随着时间的推移,它通过几个看似随机的突变发生,最合适的排列幸存下来。在我看来,这个基因汽车模拟器是面临挑战的物种实际进化的一个很好的例子。但这是游戏吗?在游戏中,您需要控制某些东西才能达到目标。这种控制违背了进化应该是什么样子。如果你允许用户选择如何进化某些东西,它就不再是进化了——它相当于智能设计,是创造论者发明的寓言,用来对抗进化论的想法。作为不可知论者和意大利面主义者,这不是以正确的方式摩擦我的东西。\"" } } }