Web server (#94)
* Init App * init server.py (#93) * init server.py * Update requirements.txt Add requirement Co-authored-by: auau <auau@test.com> Co-authored-by: babysor00 <babysor00@gmail.com> * Run web.py! Run web.py! Co-authored-by: balala <Ozgay@users.noreply.github.com> Co-authored-by: auau <auau@test.com>pull/96/head^2
parent
4178416385
commit
ddd478c0ad
@ -0,0 +1,11 @@
|
|||||||
|
from web import webApp
|
||||||
|
from gevent import pywsgi as wsgi
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
app = webApp()
|
||||||
|
host = app.config.get("HOST")
|
||||||
|
port = app.config.get("PORT")
|
||||||
|
print(f"Web server: http://{host}:{port}")
|
||||||
|
server = wsgi.WSGIServer((host, port), app)
|
||||||
|
server.serve_forever()
|
@ -0,0 +1,10 @@
|
|||||||
|
|
||||||
|
FROM python:3.7
|
||||||
|
|
||||||
|
RUN pip install gevent uwsgi flask
|
||||||
|
|
||||||
|
COPY app.py /app.py
|
||||||
|
|
||||||
|
EXPOSE 3000
|
||||||
|
|
||||||
|
ENTRYPOINT ["uwsgi", "--http", ":3000", "--master", "--module", "app:app"]
|
@ -0,0 +1,167 @@
|
|||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from gevent import pywsgi as wsgi
|
||||||
|
from flask import Flask, jsonify, Response, request, render_template
|
||||||
|
from synthesizer.inference import Synthesizer
|
||||||
|
from encoder import inference as encoder
|
||||||
|
from vocoder.hifigan import inference as gan_vocoder
|
||||||
|
from vocoder.wavernn import inference as rnn_vocoder
|
||||||
|
import numpy as np
|
||||||
|
import re
|
||||||
|
from scipy.io.wavfile import write, read
|
||||||
|
import io
|
||||||
|
import base64
|
||||||
|
from flask_cors import CORS
|
||||||
|
from flask_wtf import CSRFProtect
|
||||||
|
|
||||||
|
def webApp():
|
||||||
|
# Init and load config
|
||||||
|
app = Flask(__name__, instance_relative_config=True)
|
||||||
|
|
||||||
|
app.config.from_object("web.config.default")
|
||||||
|
|
||||||
|
CORS(app) #允许跨域,注释掉此行则禁止跨域请求
|
||||||
|
csrf = CSRFProtect(app)
|
||||||
|
csrf.init_app(app)
|
||||||
|
# API For Non-Trainer
|
||||||
|
# 1. list sample audio files
|
||||||
|
# 2. record / upload / select audio files
|
||||||
|
# 3. load melspetron of audio
|
||||||
|
# 4. inference by audio + text + models(encoder, vocoder, synthesizer)
|
||||||
|
# 5. export result
|
||||||
|
audio_samples = []
|
||||||
|
AUDIO_SAMPLES_DIR = app.config.get("AUDIO_SAMPLES_DIR")
|
||||||
|
if os.path.isdir(AUDIO_SAMPLES_DIR):
|
||||||
|
audio_samples = list(Path(AUDIO_SAMPLES_DIR).glob("*.wav"))
|
||||||
|
print("Loaded samples: " + str(len(audio_samples)))
|
||||||
|
# enc_models_dir = "encoder/saved_models"
|
||||||
|
# voc_models_di = "vocoder/saved_models"
|
||||||
|
# encoders = list(Path(enc_models_dir).glob("*.pt"))
|
||||||
|
# vocoders = list(Path(voc_models_di).glob("**/*.pt"))
|
||||||
|
syn_models_dirt = "synthesizer/saved_models"
|
||||||
|
synthesizers = list(Path(syn_models_dirt).glob("**/*.pt"))
|
||||||
|
# print("Loaded encoder models: " + str(len(encoders)))
|
||||||
|
# print("Loaded vocoder models: " + str(len(vocoders)))
|
||||||
|
print("Loaded synthesizer models: " + str(len(synthesizers)))
|
||||||
|
synthesizers_cache = {}
|
||||||
|
encoder.load_model(Path("encoder/saved_models/pretrained.pt"))
|
||||||
|
gan_vocoder.load_model(Path("vocoder/saved_models/pretrained/g_hifigan.pt"))
|
||||||
|
|
||||||
|
# TODO: move to utils
|
||||||
|
def generate(wav_path):
|
||||||
|
with open(wav_path, "rb") as fwav:
|
||||||
|
data = fwav.read(1024)
|
||||||
|
while data:
|
||||||
|
yield data
|
||||||
|
data = fwav.read(1024)
|
||||||
|
|
||||||
|
@app.route("/api/audios", methods=["GET"])
|
||||||
|
def audios():
|
||||||
|
return jsonify(
|
||||||
|
{"data": list(a.name for a in audio_samples), "total": len(audio_samples)}
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.route("/api/audios/<name>", methods=["GET"])
|
||||||
|
def audio_play(name):
|
||||||
|
return Response(generate(AUDIO_SAMPLES_DIR + name), mimetype="audio/x-wav")
|
||||||
|
|
||||||
|
@app.route("/api/models", methods=["GET"])
|
||||||
|
def models():
|
||||||
|
return jsonify(
|
||||||
|
{
|
||||||
|
# "encoder": list(e.name for e in encoders),
|
||||||
|
# "vocoder": list(e.name for e in vocoders),
|
||||||
|
"synthesizers":
|
||||||
|
list({"name": e.name, "path": str(e)} for e in synthesizers),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def pcm2float(sig, dtype='float32'):
|
||||||
|
"""Convert PCM signal to floating point with a range from -1 to 1.
|
||||||
|
Use dtype='float32' for single precision.
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
sig : array_like
|
||||||
|
Input array, must have integral type.
|
||||||
|
dtype : data type, optional
|
||||||
|
Desired (floating point) data type.
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
numpy.ndarray
|
||||||
|
Normalized floating point data.
|
||||||
|
See Also
|
||||||
|
--------
|
||||||
|
float2pcm, dtype
|
||||||
|
"""
|
||||||
|
sig = np.asarray(sig)
|
||||||
|
if sig.dtype.kind not in 'iu':
|
||||||
|
raise TypeError("'sig' must be an array of integers")
|
||||||
|
dtype = np.dtype(dtype)
|
||||||
|
if dtype.kind != 'f':
|
||||||
|
raise TypeError("'dtype' must be a floating point type")
|
||||||
|
|
||||||
|
i = np.iinfo(sig.dtype)
|
||||||
|
abs_max = 2 ** (i.bits - 1)
|
||||||
|
offset = i.min + abs_max
|
||||||
|
return (sig.astype(dtype) - offset) / abs_max
|
||||||
|
|
||||||
|
# Cache for synthesizer
|
||||||
|
@csrf.exempt
|
||||||
|
@app.route("/api/synthesize", methods=["POST"])
|
||||||
|
def synthesize():
|
||||||
|
# TODO Implementation with json to support more platform
|
||||||
|
|
||||||
|
# Load synthesizer
|
||||||
|
if "synt_path" in request.form:
|
||||||
|
synt_path = request.form["synt_path"]
|
||||||
|
else:
|
||||||
|
synt_path = synthesizers[0]
|
||||||
|
print("NO synthsizer is specified, try default first one.")
|
||||||
|
if synthesizers_cache.get(synt_path) is None:
|
||||||
|
current_synt = Synthesizer(Path(synt_path))
|
||||||
|
synthesizers_cache[synt_path] = current_synt
|
||||||
|
else:
|
||||||
|
current_synt = synthesizers_cache[synt_path]
|
||||||
|
print("using synthesizer model: " + str(synt_path))
|
||||||
|
# Load input wav
|
||||||
|
wav_base64 = request.form["upfile_b64"]
|
||||||
|
wav = base64.b64decode(bytes(wav_base64, 'utf-8'))
|
||||||
|
wav = pcm2float(np.frombuffer(wav, dtype=np.int16), dtype=np.float32)
|
||||||
|
encoder_wav = encoder.preprocess_wav(wav, 16000)
|
||||||
|
embed, _, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
|
||||||
|
|
||||||
|
# Load input text
|
||||||
|
texts = request.form["text"].split("\n")
|
||||||
|
punctuation = '!,。、,' # punctuate and split/clean text
|
||||||
|
processed_texts = []
|
||||||
|
for text in texts:
|
||||||
|
for processed_text in re.sub(r'[{}]+'.format(punctuation), '\n', text).split('\n'):
|
||||||
|
if processed_text:
|
||||||
|
processed_texts.append(processed_text.strip())
|
||||||
|
texts = processed_texts
|
||||||
|
|
||||||
|
# synthesize and vocode
|
||||||
|
embeds = [embed] * len(texts)
|
||||||
|
specs = current_synt.synthesize_spectrograms(texts, embeds)
|
||||||
|
spec = np.concatenate(specs, axis=1)
|
||||||
|
wav = gan_vocoder.infer_waveform(spec)
|
||||||
|
|
||||||
|
# Return cooked wav
|
||||||
|
out = io.BytesIO()
|
||||||
|
write(out, Synthesizer.sample_rate, wav)
|
||||||
|
return Response(out, mimetype="audio/wav")
|
||||||
|
|
||||||
|
@app.route('/', methods=['GET'])
|
||||||
|
def index():
|
||||||
|
return render_template("index.html")
|
||||||
|
|
||||||
|
host = app.config.get("HOST")
|
||||||
|
port = app.config.get("PORT")
|
||||||
|
print(f"Web server: http://{host}:{port}")
|
||||||
|
server = wsgi.WSGIServer((host, port), app)
|
||||||
|
server.serve_forever()
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
webApp()
|
@ -0,0 +1,7 @@
|
|||||||
|
AUDIO_SAMPLES_DIR = 'samples\\'
|
||||||
|
DEVICE = '0'
|
||||||
|
HOST = 'localhost'
|
||||||
|
PORT = 8080
|
||||||
|
MAX_CONTENT_PATH =1024 * 1024 * 4 # mp3文件大小限定不能超过4M
|
||||||
|
SECRET_KEY = "mockingbird_key"
|
||||||
|
WTF_CSRF_SECRET_KEY = "mockingbird_key"
|
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue