Coverage for Python_files/learner.py: 0%
Shortcuts on this page
r, m, x: toggle line displays
j, k: next/previous highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
Shortcuts on this page
r, m, x: toggle line displays
j, k: next/previous highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
1"""
2The aim of this code is to estimate the cascade's parameters.
3"""
7import argparse # To parse command line arguments
8import json # To parse and dump JSON
9from kafka import KafkaConsumer # Import Kafka consumer
10from kafka import KafkaProducer # Import Kafka producer
11import os
12import numpy as np
13from sklearn.ensemble import RandomForestRegressor
15import hawkes_tools as HT
16import logger
# Kafka topic the training samples are read from.
TOPIC_READING = "samples"
# Kafka topic the learner's output messages are published to.
TOPIC_WRITING = "models"


def parse_args(argv=None):
    """Parse the learner's command-line options.

    :param argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    :return: namespace with a ``broker_list`` attribute (default "localhost:9092").
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--broker-list', type=str, help="the broker list", default="localhost:9092")
    return parser.parse_args(argv)


def decode_key(raw):
    """Decode a Kafka message key as UTF-8; return None for messages sent without a key."""
    return None if raw is None else raw.decode('utf-8')


def make_message_key(value):
    """Return the key of the output message for one sample: its ``T_obs`` as a string.

    The producer's key serializer is ``str.encode``, which only accepts ``str``,
    so a numeric ``T_obs`` has to be converted before sending.
    """
    return str(value["T_obs"])


def make_output_message(value):
    """Build the payload published on TOPIC_WRITING for one input sample.

    :param value: decoded sample message; only its ``T_obs`` field is read here.
    :return: dict with ``type`` and ``n_obs`` entries.
    """
    # NOTE(review): the fitted model is not part of this payload yet, and the type
    # tag says 'parameters' although the topic is "models" — confirm against the
    # consumer of TOPIC_WRITING.
    return {
        'type': 'parameters',
        'n_obs': value["T_obs"],
    }


def to_training_arrays(value):
    """Convert a sample's ``X`` / ``W`` fields into arrays accepted by scikit-learn.

    ``fit`` needs a 2-D feature matrix and a 1-D target vector. A single feature
    vector becomes a one-row matrix, and a scalar target becomes a length-1 vector.
    Inputs that are already 2-D / 1-D are kept as they are.

    :param value: decoded sample message with ``X`` (features) and ``W`` (target).
    :return: ``(features, targets)`` as float NumPy arrays.
    """
    features = np.atleast_2d(np.asarray(value["X"], dtype=float))
    targets = np.ravel(np.asarray(value["W"], dtype=float))
    return features, targets


if __name__ == "__main__":
    args = parse_args()

    # Bound under a different name so the imported `logger` module is not shadowed.
    # Logs go through the same brokers as the consumer/producer.
    log = logger.get_logger('learner', broker_list=args.broker_list, debug=True)

    ################################################
    #######         Kafka Part              ########
    ################################################

    consumer = KafkaConsumer(
        TOPIC_READING,
        bootstrap_servers=args.broker_list,                          # brokers from the command line
        value_deserializer=lambda v: json.loads(v.decode('utf-8')),  # JSON payload -> dict
        key_deserializer=decode_key,                                 # tolerate keyless messages
    )

    producer = KafkaProducer(
        bootstrap_servers=args.broker_list,
        value_serializer=lambda v: json.dumps(v).encode('utf-8'),    # dict -> JSON bytes
        key_serializer=str.encode,                                   # keys must be str (see make_message_key)
    )

    ################################################
    #######         Stats part              ########
    ################################################

    log.info("Start reading in the samples topic...")
    try:
        for msg in consumer:
            sample = msg.value
            cid = sample["cid"]
            features, targets = to_training_arrays(sample)
            # TODO: append X and W to a dataset and only train once its size exceeds
            # a threshold; for now a fresh model is fitted on this single sample.
            # The fitted model is not published yet (see make_output_message).
            RandomForestRegressor().fit(features, targets)
            log.info(f"Sending estimated parameter for {cid}...")
            producer.send(TOPIC_WRITING, key=make_message_key(sample), value=make_output_message(sample))
    finally:
        # send() is asynchronous: flush so buffered messages are not lost on shutdown.
        producer.flush()
        producer.close()
        consumer.close()